Mirror of https://github.com/huggingface/candle.git
Synced 2025-06-18 11:37:11 +00:00

Compare commits: meshgrid-f ... llama2-was
41 Commits
Commits (SHA1):
b97463098c
fbd69f952c
6c990a33ea
1704f1b3ae
693fad511c
36fb84f038
c12ad45562
7d0202710b
392a00a147
4c967b9184
c05c0a8213
969960847a
5fc66bd4ba
174b208052
154c674a79
7bbde55c61
c3f2676d49
46d6566c99
55bc3382cf
dece37c6f4
498c50348c
012ae0090e
95a857cf57
612f5b8156
ef33df7ae2
c8face3f95
85bea43e5b
b3181455d5
e2826e70b3
916619f70b
9b1158b315
70d06ab4b0
0ec5ebcec4
c8e197f68c
5f20697918
e37b487767
e5dc8cb4f4
e7b886d56f
6a446d9d73
0acd16751d
c698e17619
BIN  .github/workflows/maturin.yml (vendored, new file)
Binary file not shown.
@@ -7,13 +7,7 @@ members = [
    "candle-nn",
    "candle-pyo3",
    "candle-transformers",
    "candle-wasm-examples/llama2-c",
    "candle-wasm-examples/segment-anything",
    "candle-wasm-examples/whisper",
    "candle-wasm-examples/yolo",
    "candle-wasm-examples/bert",
    "candle-wasm-examples/phi",
    "candle-wasm-examples/t5",
    "candle-wasm-examples/*",
    "candle-wasm-tests",
]
exclude = ["candle-flash-attn", "candle-kernels"]
README.md (+10)
@@ -56,6 +56,7 @@ These online demos run entirely in your browser:
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.

We also provide some command line based examples using state of the art models:

@@ -95,12 +96,15 @@ We also provide some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">

- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
  [JinaBert](./candle-examples/examples/jina-bert/): useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
  using self-supervision (can be used for imagenet classification, depth
  evaluation, segmentation).
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
  generate captions for an image.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
  model, generates the translated text from the input text.

Run them using commands like:
```

@@ -135,6 +139,8 @@ And then head over to

## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
  very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
  including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
  that conforms to the official `peft` implementation.

@@ -170,6 +176,8 @@ If you have an addition to this list, please submit a pull request.
- Wurstchen v2.
- Image to text.
  - BLIP.
- Text to text.
  - Marian MT (Machine Translation).
- Computer Vision Models.
  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
  - yolo-v3, yolo-v8.
@@ -238,6 +238,13 @@ impl Tensor {
                    .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
                    .transpose(0, 1)?;
                let sum_grad = grads.or_insert(kernel)?;
                let (_, _, k0, k1) = kernel.dims4()?;
                let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
                let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
                    grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
                } else {
                    grad_kernel
                };
                *sum_grad = sum_grad.add(&grad_kernel)?;
            }
            Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
@@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" })?,
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" })?,
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
@@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let src = match layout.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "index-select" })?,
            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
        };
        let dim = self.dim;
        let n_ids = match self.ids_l.dims() {
@@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };

@@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {

        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" })?,
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        for left_i in 0..ids_left_len {
            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
@@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "index-add" })?,
            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };
        let dim = self.dim;
@@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage {
            Self::U8(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::U32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I64(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
        }
    }
@@ -185,11 +185,17 @@ impl Device {
                Ok(Storage::Cpu(storage))
            }
            Device::Cuda(device) => {
                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
                if dtype == DType::F16 || dtype == DType::BF16 {
                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
                } else {
                    let storage = device.rand_uniform(shape, dtype, lo, up)?;
                    Ok(Storage::Cuda(storage))
                }
            }
        }
    }

    pub(crate) fn rand_uniform<T: crate::FloatDType>(
        &self,
@@ -213,11 +219,17 @@ impl Device {
                Ok(Storage::Cpu(storage))
            }
            Device::Cuda(device) => {
                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
                if dtype == DType::F16 || dtype == DType::BF16 {
                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
                } else {
                    let storage = device.rand_normal(shape, dtype, mean, std)?;
                    Ok(Storage::Cuda(storage))
                }
            }
        }
    }

    pub(crate) fn rand_normal<T: crate::FloatDType>(
        &self,
@@ -125,3 +125,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
        self(xs)
    }
}

// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}

impl<M: Module> ModuleT for M {
    fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
        self.forward(xs)
    }
}
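As a hedged illustration (not part of the diff): a layer whose behavior depends on the train/eval flag would implement `ModuleT` directly instead of relying on the blanket impl above. The `DropoutLayer` type is hypothetical, and the `candle_nn::ops::dropout` helper is an assumption about the companion crate:

```rust
use candle::{ModuleT, Result, Tensor};

/// Hypothetical dropout-style layer; `p` is the drop probability.
struct DropoutLayer {
    p: f32,
}

impl ModuleT for DropoutLayer {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            // Apply dropout only while training (helper assumed available).
            candle_nn::ops::dropout(xs, self.p)
        } else {
            // Identity at evaluation time.
            Ok(xs.clone())
        }
    }
}
```

Because of the blanket `impl<M: Module> ModuleT for M`, any train-agnostic `Module` can also be passed wherever a `ModuleT` is expected, which is what the mnist example later in this diff relies on.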
@@ -536,7 +536,6 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
@@ -666,6 +665,40 @@ impl UnaryOpT for Erf {
    }
}

impl UnaryOpT for Abs {
    const NAME: &'static str = "abs";
    const KERNEL: &'static str = "uabs";
    const V: Self = Abs;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.abs()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.abs()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.abs()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.abs()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v.abs()
    }
}

impl UnaryOpT for Ceil {
    const NAME: &'static str = "ceil";
    const KERNEL: &'static str = "uceil";
@@ -887,6 +920,10 @@ impl BackpropOp {
        };
        Self(op)
    }

    pub(crate) fn is_none(&self) -> bool {
        self.0.is_none()
    }
}

impl std::ops::Deref for BackpropOp {
@@ -50,14 +50,9 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
    let qk = QK8_0;
    let nb = n / qk;
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }

    unsafe {
        let mut acc = _mm256_setzero_ps();
        for (x, y) in xs.iter().zip(ys.iter()) {
@@ -236,14 +236,9 @@ impl GgmlType for BlockQ4_0 {

    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let qk = QK8_0;
        let nb = n / qk;
        if n % QK8_0 != 0 {
            crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
        }
        if nb % 2 != 0 {
            crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
        }

        // Generic implementation.
        let mut sumf = 0f32;
        for (xs, ys) in xs.iter().zip(ys.iter()) {
@@ -19,42 +19,29 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }

    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        let mut sumv1 = vdupq_n_f32(0.0f32);
        for i in (0..nb).step_by(2) {
        for i in 0..nb {
            let x0 = &xs[i];
            let x1 = &xs[i + 1];
            let y0 = &ys[i];
            let y1 = &ys[i + 1];

            let m4b = vdupq_n_u8(0x0F);
            let s8b = vdupq_n_s8(0x8);

            let v0_0 = vld1q_u8(x0.qs.as_ptr());
            let v0_1 = vld1q_u8(x1.qs.as_ptr());

            // 4-bit -> 8-bit
            let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
            let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
            let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
            let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

            // sub 8
            let v0_0ls = vsubq_s8(v0_0l, s8b);
            let v0_0hs = vsubq_s8(v0_0h, s8b);
            let v0_1ls = vsubq_s8(v0_1l, s8b);
            let v0_1hs = vsubq_s8(v0_1h, s8b);

            // load y
            let v1_0l = vld1q_s8(y0.qs.as_ptr());
            let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
            let v1_1l = vld1q_s8(y1.qs.as_ptr());
            let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));

            // TODO: Support dotprod when it's available outside of nightly.
            let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
@@ -62,28 +49,16 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
            let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
            let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));

            let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
            let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
            let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
            let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));

            let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
            let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
            let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
            let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
            sumv1 = vmlaq_n_f32(
                sumv1,
                vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
                x1.d.to_f32() * y1.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
        Ok(vaddvq_f32(sumv0))
    }
}

@@ -94,28 +69,18 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
    }
    unsafe {
        let mut sumv0 = vdupq_n_f32(0.0f32);
        let mut sumv1 = vdupq_n_f32(0.0f32);
        for i in (0..nb).step_by(2) {
        for i in 0..nb {
            let x0 = &xs[i];
            let x1 = &xs[i + 1];
            let y0 = &ys[i];
            let y1 = &ys[i + 1];

            let x0_0 = vld1q_s8(x0.qs.as_ptr());
            let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
            let x1_0 = vld1q_s8(x1.qs.as_ptr());
            let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));

            // load y
            let y0_0 = vld1q_s8(y0.qs.as_ptr());
            let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
            let y1_0 = vld1q_s8(y1.qs.as_ptr());
            let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));

            // TODO dotprod once this is the intrinsics are.
            let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
@@ -123,28 +88,16 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
            let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
            let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));

            let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
            let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
            let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
            let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));

            let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
            let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
            let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
            let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));

            sumv0 = vmlaq_n_f32(
                sumv0,
                vcvtq_f32_s32(vaddq_s32(p0, p1)),
                x0.d.to_f32() * y0.d.to_f32(),
            );
            sumv1 = vmlaq_n_f32(
                sumv1,
                vcvtq_f32_s32(vaddq_s32(p2, p3)),
                x1.d.to_f32() * y1.d.to_f32(),
            );
        }
        Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
        Ok(vaddvq_f32(sumv0))
    }
}
@@ -11,10 +11,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
@@ -61,10 +57,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
    if n % QK8_0 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
    }
    let nb = n / QK8_0;
    if nb % 2 != 0 {
        crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
    }
    unsafe {
        let mut acc = f32x4_splat(0.0f32);
        for (x, y) in xs.iter().zip(ys.iter()) {
@@ -203,7 +203,7 @@ impl Shape {

    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
    /// broadcasted shape. This is to be used for binary pointwise ops.
    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
        let lhs = self;
        let lhs_dims = lhs.dims();
        let rhs_dims = rhs.dims();
@@ -385,12 +385,22 @@ impl Tensor {
        step: D,
        device: &Device,
    ) -> Result<Self> {
        if D::is_zero(&step) {
            crate::bail!("step cannot be zero")
        }
        let mut data = vec![];
        let mut current = start;
        if step >= D::zero() {
            while current < end {
                data.push(current);
                current += step;
            }
        } else {
            while current > end {
                data.push(current);
                current += step;
            }
        }
        let len = data.len();
        Self::from_vec_impl(data, len, device, false)
    }
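For reference, a short usage sketch of the stepped-range logic above, matching the `arange_step` tests added later in this diff: a negative step counts down from `start` until `end` is passed.

```rust
use candle::{Device, Tensor};

fn arange_demo() -> candle::Result<()> {
    // Counting down with a negative step.
    let t = Tensor::arange_step(5i64, 0i64, -1, &Device::Cpu)?;
    assert_eq!(t.to_vec1::<i64>()?, [5, 4, 3, 2, 1]);
    Ok(())
}
```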
@@ -1186,14 +1196,16 @@ impl Tensor {
                op: "scatter-add (self, src)",
                lhs: self.shape().clone(),
                rhs: source.shape().clone(),
            })?
            }
            .bt())?
        }
        if indexes.dims() != source.dims() {
            Err(Error::ShapeMismatchBinaryOp {
                op: "scatter-add (indexes, src)",
                lhs: indexes.shape().clone(),
                rhs: source.shape().clone(),
            })?
            }
            .bt())?
        }
        let storage = self.storage().scatter_add(
            self.layout(),
@@ -1265,7 +1277,8 @@ impl Tensor {
                op: "slice-scatter (self, src)",
                lhs: self.shape().clone(),
                rhs: src.shape().clone(),
            })?
            }
            .bt())?
        }
        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
        self.storage()
@@ -1299,7 +1312,8 @@ impl Tensor {
                op: "index-add (self, source)",
                lhs: self.shape().clone(),
                rhs: source.shape().clone(),
            })?
            }
            .bt())?
        }
        // The number of element in indexes must match the dimension on which the add is
        // performed on the source tensor (and the index values from `indexes` are taken from
@@ -1310,7 +1324,8 @@ impl Tensor {
                op: "index-add (ids, source))",
                lhs: indexes.shape().clone(),
                rhs: source.shape().clone(),
            })?
            }
            .bt())?
        }
        let storage = self.storage().index_add(
            self.layout(),
@@ -1358,7 +1373,8 @@ impl Tensor {
                op: "gather",
                lhs: self.shape().clone(),
                rhs: indexes.shape().clone(),
            })?
            }
            .bt())?
        }
        let storage =
            self.storage()
@@ -1791,7 +1807,12 @@ impl Tensor {

    /// Returns a new tensor detached from the current graph, gradient are not propagated through
    /// this new node. The storage of this tensor is shared with the initial tensor.
    ///
    /// If the tensor is already detached from the computation graph, the same tensor is returned.
    pub fn detach(&self) -> Result<Tensor> {
        if self.op.is_none() && !self.is_variable {
            Ok(self.clone())
        } else {
            let tensor_ = Tensor_ {
                id: TensorId::new(),
                storage: self.storage.clone(),
@@ -1803,6 +1824,7 @@ impl Tensor {
            };
            Ok(Tensor(Arc::new(tensor_)))
        }
    }

    /// If the target device is the same as the tensor device, only a shallow copy is performed.
    pub fn to_device(&self, device: &Device) -> Result<Tensor> {
@@ -2265,6 +2287,11 @@ impl Tensor {
        m.forward(self)
    }

    /// Run the `forward` method of `m` on `self`.
    pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
        m.forward_t(self, train)
    }

    pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
        self.storage.read().unwrap()
    }
@@ -479,6 +479,71 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
            ]
        ]
    );

    // Replicate the issue from https://github.com/huggingface/candle/issues/1212
    let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
    let loss = res.sqr()?.sum_all()?;
    assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
    let grads = loss.backward()?;
    let grad_t = grads.get(&t).unwrap();
    let grad_w = grads.get(&w).unwrap();
    assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
    assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
    assert_eq!(
        test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
        [
            [
                [9.29, -7.03, 7.87, 0.0, 0.0],
                [-1.8, -7.82, 5.9, 0.0, 0.0],
                [-3.12, 4.49, 5.52, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [21.73, 3.39, 4.77, 0.0, 0.0],
                [8.25, 3.73, 27.61, 0.0, 0.0],
                [-20.55, -5.61, -2.77, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [-8.98, 9.91, -7.15, 0.0, 0.0],
                [4.93, -0.33, 4.56, 0.0, 0.0],
                [-6.7, -5.76, -8.05, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ],
            [
                [23.54, 6.98, -10.0, 0.0, 0.0],
                [9.65, 6.18, 18.72, 0.0, 0.0],
                [3.29, -5.27, 0.79, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0]
            ]
        ]
    );
    assert_eq!(
        test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
        [
            [
                [-3.47, 7.44, 0.66],
                [12.89, -3.4, -9.29],
                [-14.16, -0.83, 7.14]
            ],
            [
                [-3.23, 5.37, -3.02],
                [-2.12, -11.24, 1.94],
                [6.97, 7.2, 2.99]
            ],
            [
                [-4.04, -3.31, 4.87],
                [-6.68, -5.68, 1.73],
                [-5.54, 4.32, 0.52]
            ],
            [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
        ]
    );

    Ok(())
}
@@ -29,7 +29,26 @@ fn ones(device: &Device) -> Result<()> {
        Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
    );
    Ok(())
}

fn arange(device: &Device) -> Result<()> {
    assert_eq!(
        Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
        [0, 1, 2, 3, 4],
    );
    assert_eq!(
        Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
        [0, 2, 4],
    );
    assert_eq!(
        Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
        [0, 3],
    );
    assert_eq!(
        Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
        [5, 4, 3, 2, 1],
    );
    Ok(())
}

@@ -1037,6 +1056,7 @@ fn randn(device: &Device) -> Result<()> {

test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
@@ -1089,3 +1109,11 @@ fn pad_with_same() -> Result<()> {
    );
    Ok(())
}

#[test]
fn i64_abs() -> Result<()> {
    let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
    let t = t.abs()?;
    assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
    Ok(())
}
@@ -21,7 +21,7 @@ half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }

@@ -149,6 +149,6 @@ pub fn main() -> anyhow::Result<()> {
    if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }

    println!();
    Ok(())
}
candle-examples/examples/jina-bert/README.md (new file, +45)
@@ -0,0 +1,45 @@
# candle-jina-bert

Jina-Bert is a general large language model with a context size of 8192, [model
card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
it can be used for two different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.

## Sentence embeddings

Jina-Bert is used to compute the sentence embeddings for a prompt. The model
weights are downloaded from the hub on the first run.

```bash
cargo run --example jina-bert --release -- --prompt "Here is a test sentence"

> [[[ 0.1595, -0.9885,  0.6494, ...,  0.3003, -0.6901, -1.2355],
>   [ 0.0374, -0.1798,  1.3359, ...,  0.6731,  0.2133, -1.6807],
>   [ 0.1700, -0.8534,  0.8924, ..., -0.1785, -0.0727, -1.5087],
>   ...
>   [-0.3113, -1.3665,  0.2027, ..., -0.2519,  0.1711, -1.5811],
>   [ 0.0907, -1.0492,  0.5382, ...,  0.0242, -0.7077, -1.0830],
>   [ 0.0369, -0.6343,  0.6105, ...,  0.0671,  0.3778, -1.1505]]]
> Tensor[[1, 7, 768], f32]
```

## Similarities

In this example, Jina-Bert is used to compute the sentence embeddings for a set
of sentences (hardcoded in the example). Cosine similarities are then computed
for each sentence pair and reported in decreasing order, so the first reported
pair contains the two sentences with the highest similarity score. The sentence
embeddings are computed using average pooling across all the sentence tokens,
including any padding.
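The similarity reported for each pair is the standard cosine similarity of the two pooled embedding vectors $u$ and $v$; this just restates what the example code below computes:

$$\cos(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert} = \frac{\sum_k u_k v_k}{\sqrt{\sum_k u_k^2}\,\sqrt{\sum_k v_k^2}}$$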
```bash
cargo run --example jina-bert --release

> score: 0.94 'The new movie is awesome' 'The new movie is so great'
> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
> score: 0.78 'I love pasta' 'Do you like pizza?'
> score: 0.68 'I love pasta' 'The new movie is awesome'
> score: 0.67 'A man is playing guitar' 'A woman watches TV'
```
candle-examples/examples/jina-bert/main.rs (new file, +180)
@@ -0,0 +1,180 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle_transformers::models::jina_bert::{BertModel, Config};

use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    model: Option<String>,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
        use hf_hub::{api::sync::Api, Repo, RepoType};
        let model = match &self.model {
            Some(model_file) => std::path::PathBuf::from(model_file),
            None => Api::new()?
                .repo(Repo::new(
                    "jinaai/jina-embeddings-v2-base-en".to_string(),
                    RepoType::Model,
                ))
                .get("model.safetensors")?,
        };
        let tokenizer = match &self.tokenizer {
            Some(file) => std::path::PathBuf::from(file),
            None => Api::new()?
                .repo(Repo::new(
                    "sentence-transformers/all-MiniLM-L6-v2".to_string(),
                    RepoType::Model,
                ))
                .get("tokenizer.json")?,
        };
        let device = candle_examples::device(self.cpu)?;
        let config = Config::v2_base();
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
        let model = BertModel::new(vb, &config)?;
        Ok((model, tokenizer))
    }
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        println!("tracing...");
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    let start = std::time::Instant::now();

    let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
    let device = &model.device;

    if let Some(prompt) = args.prompt {
        let tokenizer = tokenizer
            .with_padding(None)
            .with_truncation(None)
            .map_err(E::msg)?;
        let tokens = tokenizer
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
        println!("Loaded and encoded {:?}", start.elapsed());
        for idx in 0..args.n {
            let start = std::time::Instant::now();
            let ys = model.forward(&token_ids)?;
            if idx == 0 {
                println!("{ys}");
            }
            println!("Took {:?}", start.elapsed());
        }
    } else {
        let sentences = [
            "The cat sits outside",
            "A man is playing guitar",
            "I love pasta",
            "The new movie is awesome",
            "The cat plays in the garden",
            "A woman watches TV",
            "The new movie is so great",
            "Do you like pizza?",
        ];
        let n_sentences = sentences.len();
        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = tokenizers::PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }
        let tokens = tokenizer
            .encode_batch(sentences.to_vec(), true)
            .map_err(E::msg)?;
        let token_ids = tokens
            .iter()
            .map(|tokens| {
                let tokens = tokens.get_ids().to_vec();
                Tensor::new(tokens.as_slice(), device)
            })
            .collect::<candle::Result<Vec<_>>>()?;

        let token_ids = Tensor::stack(&token_ids, 0)?;
        println!("running inference on batch {:?}", token_ids.shape());
        let embeddings = model.forward(&token_ids)?;
        println!("generated embeddings {:?}", embeddings.shape());
        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
        let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
        let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
        let embeddings = if args.normalize_embeddings {
            normalize_l2(&embeddings)?
        } else {
            embeddings
        };
        println!("pooled embeddings {:?}", embeddings.shape());

        let mut similarities = vec![];
        for i in 0..n_sentences {
            let e_i = embeddings.get(i)?;
            for j in (i + 1)..n_sentences {
                let e_j = embeddings.get(j)?;
                let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
                let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
                let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
                similarities.push((cosine_similarity, i, j))
            }
        }
        similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
        for &(score, i, j) in similarities[..5].iter() {
            println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
        }
    }
    Ok(())
}

pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}
@@ -6,9 +6,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod model;
use candle_transformers::models::llama2_c as model;
use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod training;
mod weights;
use clap::{Parser, Subcommand};

use anyhow::{Error as E, Result};
@@ -19,6 +20,7 @@ use std::io::Write;
use tokenizers::Tokenizer;

use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;

#[derive(Parser, Debug, Clone)]
@@ -152,6 +154,20 @@ fn main() -> anyhow::Result<()> {
    Ok(())
}

enum Model {
    Llama(Llama),
    QLlama(QLlama),
}

impl Model {
    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
        match self {
            Self::Llama(l) => Ok(l.forward(xs, pos)?),
            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
        }
    }
}

fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
    use std::io::BufRead;

@@ -241,24 +257,66 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

    let device = candle_examples::device(common_args.cpu)?;

    let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
    let is_safetensors = config_path
        .extension()
        .map_or(false, |v| v == "safetensors");
    let (vb, config) = if is_safetensors {
        let config = Config::tiny();
    let (model, config) = if is_gguf {
        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
        let (_vocab_size, dim) = vb
            .get_no_shape("model.embed_tokens.weight")?
            .shape()
            .dims2()?;
        let config = match dim {
            64 => Config::tiny_260k(),
            288 => Config::tiny_15m(),
            512 => Config::tiny_42m(),
            768 => Config::tiny_110m(),
            _ => anyhow::bail!("no config for dim {dim}"),
        };
        let freq_cis_real = vb
            .get(
                (config.seq_len, config.head_size() / 2),
                "rot.freq_cis_real",
            )?
            .dequantize(&candle::Device::Cpu)?;
        let freq_cis_imag = vb
            .get(
                (config.seq_len, config.head_size() / 2),
                "rot.freq_cis_imag",
            )?
            .dequantize(&candle::Device::Cpu)?;

        let fake_vb = candle_nn::VarBuilder::from_tensors(
            [
                ("freq_cis_real".to_string(), freq_cis_real),
                ("freq_cis_imag".to_string(), freq_cis_imag),
            ]
            .into_iter()
            .collect(),
            candle::DType::F32,
            &candle::Device::Cpu,
        );
        let cache = model::Cache::new(true, &config, fake_vb)?;
        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
        (model, config)
    } else if is_safetensors {
        let config = Config::tiny_15m();
        let tensors = candle::safetensors::load(config_path, &device)?;
        let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
        (vb, config)
        let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
        (model, config)
    } else {
        let mut file = std::fs::File::open(config_path)?;
        let config = Config::from_reader(&mut file)?;
        println!("{config:?}");
        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
        let vb = weights.var_builder(&config, &device)?;
        (vb, config)
    };
    let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
    let model = Llama::load(vb, &cache, config)?;
        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
        (model, config)
    };

    println!("starting the inference loop");
    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
@@ -273,7 +331,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

    let start_gen = std::time::Instant::now();
    for index in 0.. {
        if tokens.len() >= model.config.seq_len {
        if tokens.len() >= config.seq_len {
            break;
        }
        let context_size = if index > 0 { 1 } else { tokens.len() };
@@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
    );
    let varmap = candle_nn::VarMap::new();
    let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let config = Config::tiny();
    let config = Config::tiny_15m();
    let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
    let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
candle-examples/examples/marian-mt/README.md (new file, +38)
@@ -0,0 +1,38 @@
# candle-marian-mt

`marian-mt` is a neural machine translation model. In this example it is used to
translate text from French to English. See the associated [model
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
the model itself.

## Running an example

```bash
cargo run --example marian-mt --release -- \
  --text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
```

```
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
I know you are waiting for me. I will go through the forest, I will go through the
mountain. I cannot stay far from you any longer.</s>
```

## Generating the tokenizer.json files

You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. It requires the `tokenizers` and `sentencepiece`
packages to be installed, and uses the `convert_slow_tokenizer.py` script from
this directory.

```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save("tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save("tokenizer-marian-base-en.json")
```
candle-examples/examples/marian-mt/convert_slow_tokenizer.py (new file, +1385)
File diff suppressed because it is too large.
candle-examples/examples/marian-mt/main.rs (new file, +152)
@@ -0,0 +1,152 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::{Parser, ValueEnum};

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
    Base,
    Big,
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    tokenizer_dec: Option<String>,

    /// Choose the variant of the model to run.
    #[arg(long, default_value = "big")]
    which: Which,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// Text to be translated
    #[arg(long)]
    text: String,
}

pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();

    let config = match args.which {
        Which::Base => marian::Config::opus_mt_fr_en(),
        Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
    };
    let tokenizer = {
        let tokenizer = match args.tokenizer {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-fr.json",
                    Which::Big => "tokenizer-marian-fr.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

    let tokenizer_dec = {
        let tokenizer = match args.tokenizer_dec {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-en.json",
                    Which::Big => "tokenizer-marian-en.json",
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };
    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;
    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-fr-en".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/4".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Big => Api::new()?
                    .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
                    .get("model.safetensors")?,
            },
        };
        unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
    };
    let mut model = marian::MTModel::new(&config, vb)?;

    let mut logits_processor =
        candle_transformers::generation::LogitsProcessor::new(1337, None, None);

    let encoder_xs = {
        let mut tokens = tokenizer
            .encode(args.text, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        tokens.push(config.eos_token_id);
        let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        model.encoder().forward(&tokens, 0)?
    };

    let mut token_ids = vec![config.decoder_start_token_id];
    for index in 0..1000 {
        let context_size = if index >= 1 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        token_ids.push(token);
        if let Some(t) = tokenizer_dec.next_token(token)? {
            use std::io::Write;
            print!("{t}");
            std::io::stdout().flush()?;
        }
        if token == config.eos_token_id || token == config.forced_eos_token_id {
            break;
        }
    }
    if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
        print!("{rest}");
    }
    println!();
    Ok(())
}
@@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
use rand::prelude::*;

use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};

const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
@@ -95,7 +95,7 @@ impl ConvNet {
            .flatten_from(1)?
            .apply(&self.fc1)?
            .relu()?;
        self.dropout.forward(&xs, train)?.apply(&self.fc2)
        self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
    }
}
@@ -124,6 +124,7 @@ enum WhichModel {
    #[value(name = "1.5")]
    V1_5,
    PuffinPhiV2,
    PhiHermes,
}

#[derive(Parser, Debug)]
@@ -224,7 +225,9 @@ fn main() -> Result<()> {
                match args.model {
                    WhichModel::V1 => "microsoft/phi-1".to_string(),
                    WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
                    WhichModel::PuffinPhiV2 => "lmz/candle-quantized-phi".to_string(),
                    WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                        "lmz/candle-quantized-phi".to_string()
                    }
                }
            }
        }
@@ -238,7 +241,7 @@ fn main() -> Result<()> {
                match args.model {
                    WhichModel::V1 => "refs/pr/2".to_string(),
                    WhichModel::V1_5 => "refs/pr/18".to_string(),
                    WhichModel::PuffinPhiV2 => "main".to_string(),
                    WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
                }
            }
        }
@@ -248,7 +251,9 @@ fn main() -> Result<()> {
        Some(file) => std::path::PathBuf::from(file),
        None => match args.model {
            WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
            WhichModel::PuffinPhiV2 => repo.get("tokenizer-puffin-phi-v2.json")?,
            WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                repo.get("tokenizer-puffin-phi-v2.json")?
            }
        },
    };
    let filename = match args.weight_file {
@@ -259,11 +264,13 @@ fn main() -> Result<()> {
                WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
                WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
                WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
                WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
            }
        } else {
            match args.model {
                WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
                WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
                WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
            }
        }
    }
@@ -276,6 +283,7 @@ fn main() -> Result<()> {
        WhichModel::V1 => Config::v1(),
        WhichModel::V1_5 => Config::v1_5(),
        WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
        WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
    };
    let (model, device) = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
candle-examples/examples/reinforcement-learning/ddpg.rs (new file, +451)
@@ -0,0 +1,451 @@
use std::collections::VecDeque;
use std::fmt::Display;

use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
    func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
    VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};

pub struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: Tensor,
}
impl OuNoise {
    pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
        Ok(Self {
            mu,
            theta,
            sigma,
            state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
        })
    }

    pub fn sample(&mut self) -> Result<Tensor> {
        let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
        let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
        self.state = (&self.state + dx)?;
        Ok(self.state.clone())
    }
}
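The `sample` step above is the discretized Ornstein-Uhlenbeck update (with a unit time step), commonly used as exploration noise in DDPG; this restates the `dx` line in the code:

$$x_{t+1} = x_t + \theta\,(\mu - x_t) + \sigma\,\varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, 1)$$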
#[derive(Clone)]
struct Transition {
    state: Tensor,
    action: Tensor,
    reward: Tensor,
    next_state: Tensor,
    terminated: bool,
    truncated: bool,
}
impl Transition {
    fn new(
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) -> Self {
        Self {
            state: state.clone(),
            action: action.clone(),
            reward: reward.clone(),
            next_state: next_state.clone(),
            terminated,
            truncated,
        }
    }
}

pub struct ReplayBuffer {
    buffer: VecDeque<Transition>,
    capacity: usize,
    size: usize,
}
impl ReplayBuffer {
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: VecDeque::with_capacity(capacity),
            capacity,
            size: 0,
        }
    }

    pub fn push(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        if self.size == self.capacity {
            self.buffer.pop_front();
        } else {
            self.size += 1;
        }
        self.buffer.push_back(Transition::new(
            state, action, reward, next_state, terminated, truncated,
        ));
    }

    #[allow(clippy::type_complexity)]
    pub fn random_batch(
        &self,
        batch_size: usize,
    ) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
        if self.size < batch_size {
            Ok(None)
        } else {
            let transitions: Vec<&Transition> = thread_rng()
                .sample_iter(Uniform::from(0..self.size))
                .take(batch_size)
                .map(|i| self.buffer.get(i).unwrap())
                .collect();

            let states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let actions: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.action.unsqueeze(0))
                .collect::<Result<_>>()?;
            let rewards: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.reward.unsqueeze(0))
                .collect::<Result<_>>()?;
            let next_states: Vec<Tensor> = transitions
                .iter()
                .map(|t| t.next_state.unsqueeze(0))
                .collect::<Result<_>>()?;
            let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
            let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();

            Ok(Some((
                Tensor::cat(&states, 0)?,
                Tensor::cat(&actions, 0)?,
                Tensor::cat(&rewards, 0)?,
                Tensor::cat(&next_states, 0)?,
                terminateds,
                truncateds,
            )))
        }
    }
}

fn track(
    varmap: &mut VarMap,
    vb: &VarBuilder,
    target_prefix: &str,
    network_prefix: &str,
    dims: &[(usize, usize)],
    tau: f64,
) -> Result<()> {
    for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
        let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
        let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.weight"),
            ((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
        )?;

        let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
        let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
        varmap.set_one(
            format!("{target_prefix}-fc{i}.bias"),
            ((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
        )?;
    }
    Ok(())
}
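The `track` helper above implements the soft (Polyak) target-network update that DDPG uses to keep target networks slowly tracking the online networks; with tau = 1.0 it simply copies the weights, which is how the actor and critic below initialize their targets:

$$\theta_{\text{target}} \leftarrow \tau\,\theta_{\text{online}} + (1 - \tau)\,\theta_{\text{target}}$$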
struct Actor<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Actor<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims = vec![(size_state, 400), (400, 300), (300, size_action)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?)
                .add(func(|xs| xs.tanh()));
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("actor")?;
        let target_network = make_network("target-actor")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0)?;

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor) -> Result<Tensor> {
        self.network.forward(state)
    }

    fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
        self.target_network.forward(state)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-actor",
            "actor",
            &self.dims,
            tau,
        )
    }
}

struct Critic<'a> {
    varmap: VarMap,
    vb: VarBuilder<'a>,
    network: Sequential,
    target_network: Sequential,
    size_state: usize,
    size_action: usize,
    dims: Vec<(usize, usize)>,
}

impl Critic<'_> {
    fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, dtype, device);

        let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];

        let make_network = |prefix: &str| {
            let seq = seq()
                .add(linear(
                    dims[0].0,
                    dims[0].1,
                    vb.pp(format!("{prefix}-fc0")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[1].0,
                    dims[1].1,
                    vb.pp(format!("{prefix}-fc1")),
                )?)
                .add(Activation::Relu)
                .add(linear(
                    dims[2].0,
                    dims[2].1,
                    vb.pp(format!("{prefix}-fc2")),
                )?);
            Ok::<Sequential, Error>(seq)
        };

        let network = make_network("critic")?;
        let target_network = make_network("target-critic")?;

        // this sets the two networks to be equal to each other using tau = 1.0
        track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0)?;

        Ok(Self {
            varmap,
            vb,
            network,
            target_network,
            size_state,
            size_action,
            dims,
        })
    }

    fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.network.forward(&xs)
    }

    fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
        let xs = Tensor::cat(&[action, state], 1)?;
        self.target_network.forward(&xs)
    }

    fn track(&mut self, tau: f64) -> Result<()> {
        track(
            &mut self.varmap,
            &self.vb,
            "target-critic",
            "critic",
            &self.dims,
            tau,
        )
    }
}

#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
    actor: Actor<'a>,
    actor_optim: AdamW,
    critic: Critic<'a>,
    critic_optim: AdamW,
    gamma: f64,
    tau: f64,
    replay_buffer: ReplayBuffer,
    ou_noise: OuNoise,

    size_state: usize,
    size_action: usize,
    pub train: bool,
}

impl DDPG<'_> {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        device: &Device,
        size_state: usize,
        size_action: usize,
        train: bool,
        actor_lr: f64,
        critic_lr: f64,
        gamma: f64,
        tau: f64,
        buffer_capacity: usize,
        ou_noise: OuNoise,
    ) -> Result<Self> {
        let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
            varmap
                .data()
                .lock()
                .unwrap()
                .iter()
                .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
                .collect::<Vec<Var>>()
        };

        let actor = Actor::new(device, DType::F32, size_state, size_action)?;
        let actor_optim = AdamW::new(
            filter_by_prefix(&actor.varmap, "actor"),
            ParamsAdamW {
                lr: actor_lr,
                ..Default::default()
            },
        )?;

        let critic = Critic::new(device, DType::F32, size_state, size_action)?;
        let critic_optim = AdamW::new(
            filter_by_prefix(&critic.varmap, "critic"),
            ParamsAdamW {
                lr: critic_lr,
                ..Default::default()
            },
        )?;

        Ok(Self {
            actor,
            actor_optim,
            critic,
            critic_optim,
            gamma,
            tau,
            replay_buffer: ReplayBuffer::new(buffer_capacity),
            ou_noise,
            size_state,
            size_action,
            train,
        })
    }

    pub fn remember(
        &mut self,
        state: &Tensor,
        action: &Tensor,
        reward: &Tensor,
        next_state: &Tensor,
        terminated: bool,
        truncated: bool,
    ) {
        self.replay_buffer
            .push(state, action, reward, next_state, terminated, truncated)
    }

    pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
        let actions = self
            .actor
            .forward(&state.detach()?.unsqueeze(0)?)?
            .squeeze(0)?;
        let actions = if self.train {
            (actions + self.ou_noise.sample()?)?
        } else {
            actions
        };
        actions.squeeze(0)?.to_scalar::<f32>()
    }

    pub fn train(&mut self, batch_size: usize) -> Result<()> {
        let (states, actions, rewards, next_states, _, _) =
            match self.replay_buffer.random_batch(batch_size)? {
                Some(v) => v,
                _ => return Ok(()),
            };

        let q_target = self
            .critic
            .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
        let q_target = (rewards + (self.gamma * q_target)?.detach())?;
        let q = self.critic.forward(&states, &actions)?;
        let diff = (q_target - q)?;

        let critic_loss = diff.sqr()?.mean_all()?;
        self.critic_optim.backward_step(&critic_loss)?;

        let actor_loss = self
            .critic
            .forward(&states, &self.actor.forward(&states)?)?
            .mean_all()?
            .neg()?;
        self.actor_optim.backward_step(&actor_loss)?;

        self.critic.track(self.tau)?;
        self.actor.track(self.tau)?;

        Ok(())
    }
}

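For reference, `train` above is the textbook DDPG step: the critic is regressed towards a bootstrapped one-step target built from the target networks, and the actor is updated to maximize the critic's estimate. As equations (a restatement of the code, with μ′ and Q′ the target actor and critic):

```latex
y = r + \gamma\, Q'\!\left(s',\, \mu'(s')\right), \qquad
L_{\mathrm{critic}} = \left(y - Q(s, a)\right)^2, \qquad
L_{\mathrm{actor}} = -\,Q\!\left(s,\, \mu(s)\right)
```
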
@@ -7,20 +7,22 @@ use pyo3::types::PyDict;
 /// The return value for a step.
 #[derive(Debug)]
 pub struct Step<A> {
-    pub obs: Tensor,
+    pub state: Tensor,
     pub action: A,
     pub reward: f64,
-    pub is_done: bool,
+    pub terminated: bool,
+    pub truncated: bool,
 }

 impl<A: Copy> Step<A> {
     /// Returns a copy of this step changing the observation tensor.
-    pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
+    pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
         Step {
-            obs: obs.clone(),
+            state: state.clone(),
             action: self.action,
             reward: self.reward,
-            is_done: self.is_done,
+            terminated: self.terminated,
+            truncated: self.truncated,
         }
     }
 }

@@ -63,14 +65,14 @@ impl GymEnv {

     /// Resets the environment, returning the observation tensor.
     pub fn reset(&self, seed: u64) -> Result<Tensor> {
-        let obs: Vec<f32> = Python::with_gil(|py| {
+        let state: Vec<f32> = Python::with_gil(|py| {
             let kwargs = PyDict::new(py);
             kwargs.set_item("seed", seed)?;
-            let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
-            obs.as_ref(py).get_item(0)?.extract()
+            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
+            state.as_ref(py).get_item(0)?.extract()
         })
         .map_err(w)?;
-        Tensor::new(obs, &Device::Cpu)
+        Tensor::new(state, &Device::Cpu)
     }

     /// Applies an environment step using the specified action.

@@ -78,21 +80,23 @@ impl GymEnv {
         &self,
         action: A,
     ) -> Result<Step<A>> {
-        let (obs, reward, is_done) = Python::with_gil(|py| {
+        let (state, reward, terminated, truncated) = Python::with_gil(|py| {
             let step = self.env.call_method(py, "step", (action.clone(),), None)?;
             let step = step.as_ref(py);
-            let obs: Vec<f32> = step.get_item(0)?.extract()?;
+            let state: Vec<f32> = step.get_item(0)?.extract()?;
             let reward: f64 = step.get_item(1)?.extract()?;
-            let is_done: bool = step.get_item(2)?.extract()?;
-            Ok((obs, reward, is_done))
+            let terminated: bool = step.get_item(2)?.extract()?;
+            let truncated: bool = step.get_item(3)?.extract()?;
+            Ok((state, reward, terminated, truncated))
         })
         .map_err(w)?;
-        let obs = Tensor::new(obs, &Device::Cpu)?;
+        let state = Tensor::new(state, &Device::Cpu)?;
         Ok(Step {
-            obs,
-            reward,
-            is_done,
+            state,
             action,
+            reward,
+            terminated,
+            truncated,
         })
     }

@@ -9,14 +9,34 @@ extern crate accelerate_src;

 mod gym_env;
 mod vec_gym_env;

-use candle::Result;
+mod ddpg;
+
+use candle::{Device, Result, Tensor};
 use clap::Parser;
+use rand::Rng;
+
+// The impact of the q value of the next state on the current state's q value.
+const GAMMA: f64 = 0.99;
+// The weight for updating the target networks.
+const TAU: f64 = 0.005;
+// The capacity of the replay buffer used for sampling training data.
+const REPLAY_BUFFER_CAPACITY: usize = 100_000;
+// The training batch size for each training iteration.
+const TRAINING_BATCH_SIZE: usize = 100;
+// The total number of episodes.
+const MAX_EPISODES: usize = 100;
+// The maximum length of an episode.
+const EPISODE_LENGTH: usize = 200;
+// The number of training iterations after one episode finishes.
+const TRAINING_ITERATIONS: usize = 200;
+
+// Ornstein-Uhlenbeck process parameters.
+const MU: f64 = 0.0;
+const THETA: f64 = 0.15;
+const SIGMA: f64 = 0.1;
+
+const ACTOR_LEARNING_RATE: f64 = 1e-4;
+const CRITIC_LEARNING_RATE: f64 = 1e-3;

 #[derive(Parser, Debug, Clone)]
 #[command(author, version, about, long_about = None)]

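The `MU`/`THETA`/`SIGMA` constants parameterize the Ornstein-Uhlenbeck exploration noise; the `OuNoise` type itself lives in `ddpg.rs` and is not shown in these hunks. The sketch below shows the usual discretized OU update such a type implements. It is an illustration only: the struct and its fields are assumptions, not the actual `ddpg::OuNoise` API.

```rust
use candle::{DType, Device, Result, Tensor};

// Hypothetical minimal OU process (illustrative, not the real `ddpg::OuNoise`).
struct OuSketch {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: Tensor,
}

impl OuSketch {
    fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
        // Starting the process at zero is a common choice.
        let state = Tensor::zeros(size_action, DType::F32, &Device::Cpu)?;
        Ok(Self { mu, theta, sigma, state })
    }

    // x <- x + theta * (mu - x) + sigma * N(0, 1)
    fn sample(&mut self) -> Result<Tensor> {
        let noise = self.state.randn_like(0.0, 1.0)?;
        // affine(a, b) computes x * a + b, so this is theta * (mu - x).
        let dx = (self.state.affine(-self.theta, self.theta * self.mu)? + (self.sigma * noise)?)?;
        self.state = (&self.state + dx)?;
        Ok(self.state.clone())
    }
}
```
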
@@ -48,28 +68,77 @@ fn main() -> Result<()> {
     println!("action space: {}", env.action_space());
     println!("observation space: {:?}", env.observation_space());

-    let _num_obs = env.observation_space().iter().product::<usize>();
-    let _num_actions = env.action_space();
+    let size_state = env.observation_space().iter().product::<usize>();
+    let size_action = env.action_space();
+
+    let mut agent = ddpg::DDPG::new(
+        &Device::Cpu,
+        size_state,
+        size_action,
+        true,
+        ACTOR_LEARNING_RATE,
+        CRITIC_LEARNING_RATE,
+        GAMMA,
+        TAU,
+        REPLAY_BUFFER_CAPACITY,
+        ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
+    )?;

     let mut rng = rand::thread_rng();

     for episode in 0..MAX_EPISODES {
-        let mut obs = env.reset(episode as u64)?;
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;

         let mut total_reward = 0.0;
         for _ in 0..EPISODE_LENGTH {
-            let actions = rng.gen_range(-2.0..2.0);
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);

-            let step = env.step(vec![actions])?;
+            let step = env.step(vec![action])?;
             total_reward += step.reward;

-            if step.is_done {
+            agent.remember(
+                &state,
+                &Tensor::new(vec![action], &Device::Cpu)?,
+                &Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
+                &step.state,
+                step.terminated,
+                step.truncated,
+            );
+
+            if step.terminated || step.truncated {
                 break;
             }
-            obs = step.obs;
+            state = step.state;
         }

         println!("episode {episode} with total reward of {total_reward}");
+
+        for _ in 0..TRAINING_ITERATIONS {
+            agent.train(TRAINING_BATCH_SIZE)?;
+        }
     }
+
+    println!("Testing...");
+    agent.train = false;
+    for episode in 0..10 {
+        // let mut state = env.reset(episode as u64)?;
+        let mut state = env.reset(rng.gen::<u64>())?;
+        let mut total_reward = 0.0;
+        for _ in 0..EPISODE_LENGTH {
+            let mut action = 2.0 * agent.actions(&state)?;
+            action = action.clamp(-2.0, 2.0);
+
+            let step = env.step(vec![action])?;
+            total_reward += step.reward;
+
+            if step.terminated || step.truncated {
+                break;
+            }
+            state = step.state;
+        }
+        println!("episode {episode} with total reward of {total_reward}");
+    }
     Ok(())
 }

candle-examples/examples/vgg/README.md (new file, +13)
@@ -0,0 +1,13 @@
## VGG Model Implementation

This example demonstrates the implementation of VGG models (VGG13, VGG16, VGG19) using the Candle library.

The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main function in `candle-examples/examples/vgg/main.rs` loads an image, selects the VGG model based on the provided argument, and applies the model to the loaded image.

You can run the example with the following command:

```bash
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
```

In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).
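To try one of the larger variants instead, only the `--which` flag changes (same test image assumed):

```bash
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg19
```
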
candle-examples/examples/vgg/main.rs (new file, +77)
@@ -0,0 +1,77 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{DType, IndexOp, D};
use candle_nn::{ModuleT, VarBuilder};
use candle_transformers::models::vgg::{Models, Vgg};
use clap::{Parser, ValueEnum};

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    Vgg13,
    Vgg16,
    Vgg19,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Variant of the model to use.
    #[arg(value_enum, long, default_value_t = Which::Vgg13)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;

    println!("loaded image {image:?}");

    let api = hf_hub::api::sync::Api::new()?;
    let repo = match args.which {
        Which::Vgg13 => "timm/vgg13.tv_in1k",
        Which::Vgg16 => "timm/vgg16.tv_in1k",
        Which::Vgg19 => "timm/vgg19.tv_in1k",
    };
    let api = api.model(repo.into());
    let filename = "model.safetensors";
    let model_file = api.get(filename)?;

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = match args.which {
        Which::Vgg13 => Vgg::new(vb, Models::Vgg13)?,
        Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
        Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
    };
    let logits = model.forward_t(&image, /*train=*/ false)?;

    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;

    // Sort the predictions and take the top 5
    let mut top: Vec<_> = prs.iter().enumerate().collect();
    top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
    let top = top.into_iter().take(5).collect::<Vec<_>>();

    // Print the top predictions
    for &(i, p) in &top {
        println!(
            "{:50}: {:.2}%",
            candle_examples::imagenet::CLASSES[i],
            p * 100.0
        );
    }

    Ok(())
}

@@ -1,7 +1,5 @@
 use candle::{DType, IndexOp, Result, Tensor, D};
-use candle_nn::{
-    batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
-};
+use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder};

 #[derive(Clone, Copy, PartialEq, Debug)]
 pub struct Multiples {
@@ -76,7 +74,6 @@ impl Module for Upsample {
 #[derive(Debug)]
 struct ConvBlock {
     conv: Conv2d,
-    bn: BatchNorm,
     span: tracing::Span,
 }

@@ -96,11 +93,10 @@ impl ConvBlock {
             groups: 1,
             dilation: 1,
         };
-        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
+        let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
         Ok(Self {
             conv,
-            bn,
             span: tracing::span!(tracing::Level::TRACE, "conv-block"),
         })
     }
@@ -110,7 +106,6 @@ impl Module for ConvBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let xs = self.conv.forward(xs)?;
-        let xs = self.bn.forward(&xs)?;
         candle_nn::ops::silu(&xs)
     }
 }

@@ -9,8 +9,11 @@ pub enum Activation {
     #[serde(rename = "gated-gelu")]
     NewGelu,
     Relu,
+    Relu2,
+    Relu6,
     Silu,
     Sigmoid,
+    Swish,
     Elu(f64),
     LeakyRelu(f64),
 }
@@ -22,8 +25,11 @@ impl super::Module for Activation {
             // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
             Self::NewGelu => xs.gelu(),
             Self::Relu => xs.relu(),
+            Self::Relu2 => xs.relu()?.sqr(),
+            Self::Relu6 => xs.clamp(0f32, 6f32),
             Self::Silu => crate::ops::silu(xs),
             Self::Sigmoid => crate::ops::sigmoid(xs),
+            Self::Swish => xs * crate::ops::sigmoid(xs)?,
             &Self::Elu(alpha) => xs.elu(alpha),
             &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
         }

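A small sketch exercising the newly added variants (input values chosen arbitrarily):

```rust
use candle::{Device, Module, Result, Tensor};
use candle_nn::Activation;

fn main() -> Result<()> {
    let xs = Tensor::new(&[-1f32, 3., 9.], &Device::Cpu)?;
    println!("{}", Activation::Relu2.forward(&xs)?); // relu(x)^2 -> [0, 9, 81]
    println!("{}", Activation::Relu6.forward(&xs)?); // clamp(x, 0, 6) -> [0, 3, 6]
    println!("{}", Activation::Swish.forward(&xs)?); // x * sigmoid(x)
    Ok(())
}
```
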
@@ -100,9 +100,23 @@ impl BatchNorm {
             num_features,
         })
     }
+}
+
+impl BatchNorm {
+    pub fn running_mean(&self) -> &Tensor {
+        &self.running_mean
+    }
+
+    pub fn running_var(&self) -> &Tensor {
+        &self.running_var
+    }
+
+    pub fn eps(&self) -> f64 {
+        self.eps
+    }
+
+    pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
+        self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
+    }

     pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {

@@ -1,4 +1,5 @@
 //! Convolution Layers.
+use crate::BatchNorm;
 use candle::{Result, Tensor};

 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -115,6 +116,26 @@ impl Conv2d {
     pub fn bias(&self) -> Option<&Tensor> {
         self.bias.as_ref()
     }
+
+    pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> {
+        if let Some((w_bn, b_bn)) = bn.weight_and_bias() {
+            let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?;
+            let weight = self
+                .weight()
+                .broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?;
+            let bias = match &self.bias {
+                None => b_bn.sub(&(std_.mul(bn.running_mean())?))?,
+                Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?,
+            };
+            Ok(Self {
+                weight,
+                bias: Some(bias),
+                config: self.config,
+            })
+        } else {
+            candle::bail!("batch norm does not have weight_and_bias")
+        }
+    }
 }

 impl crate::Module for Conv2d {

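For reference, `absorb_bn` folds the batch-norm statistics into the convolution so that a single conv reproduces conv-then-BN at inference time. Writing γ, β for the BN weight and bias, μ, σ² for the running mean and variance, and b for the optional conv bias, the code above computes:

```latex
s = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
W' = s \cdot W, \qquad
b' =
\begin{cases}
\beta - s\,\mu & \text{no conv bias} \\
\beta + s\,(b - \mu) & \text{with conv bias } b
\end{cases}
```

This is what lets the yolo-v8 change above drop the separate `bn` field and the extra forward-pass step.
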
@@ -36,3 +36,38 @@ impl<'a> Func<'a> {
         Self { f: Arc::new(f) }
     }
 }
+
+/// A layer defined by a simple closure.
+#[derive(Clone)]
+pub struct FuncT<'a> {
+    #[allow(clippy::type_complexity)]
+    f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
+}
+
+impl<'a> std::fmt::Debug for FuncT<'a> {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "func")
+    }
+}
+
+pub fn func_t<'a, F>(f: F) -> FuncT<'a>
+where
+    F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
+{
+    FuncT { f: Arc::new(f) }
+}
+
+impl<'a> super::ModuleT for FuncT<'a> {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
+        (*self.f)(xs, train)
+    }
+}
+
+impl<'a> FuncT<'a> {
+    pub fn new<F>(f: F) -> Self
+    where
+        F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
+    {
+        Self { f: Arc::new(f) }
+    }
+}

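A quick sketch of how the new `func_t` helper can be combined with the `ModuleT` impl for `Dropout` added below, to build a closure layer whose behaviour depends on the train flag. The scaling factor and dropout probability are arbitrary illustration values:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{func_t, ops::Dropout, ModuleT};

fn main() -> Result<()> {
    let dropout = Dropout::new(0.1);
    // A train-aware closure layer: scale, then drop out only in train mode.
    let layer = func_t(move |xs, train| {
        let xs = (xs * 2.0)?;
        dropout.forward_t(&xs, train)
    });
    let xs = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    let ys = layer.forward_t(&xs, /* train= */ false)?;
    println!("{ys}");
    Ok(())
}
```
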
@@ -22,7 +22,7 @@ pub use conv::{
     Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
 };
 pub use embedding::{embedding, Embedding};
-pub use func::{func, Func};
+pub use func::{func, func_t, Func, FuncT};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
@@ -34,4 +34,4 @@ pub use sequential::{seq, Sequential};
 pub use var_builder::VarBuilder;
 pub use var_map::VarMap;

-pub use candle::Module;
+pub use candle::{Module, ModuleT};

@@ -84,6 +84,12 @@ impl Dropout {
     }
 }

+impl candle::ModuleT for Dropout {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
+        self.forward(xs, train)
+    }
+}
+
 struct SoftmaxLastDim;

 impl candle::CustomOp1 for SoftmaxLastDim {

@ -19,10 +19,10 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
half = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
pyo3 = { version = "0.19.0", features = ["extension-module"] }
|
||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = "0.19"
|
||||
pyo3-build-config = "0.20"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
@@ -53,3 +53,39 @@ class Tensor:
         Return a slice of a tensor.
         """
         pass
+
+    def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass

@@ -1,7 +1,7 @@
 # Generated content DO NOT EDIT
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
 from os import PathLike
-from candle.typing import _ArrayLike, Device, Scalar, Index
+from candle.typing import _ArrayLike, Device, Scalar, Index, Shape

 class bf16(DType):
     pass
@@ -26,21 +26,21 @@ class i64(DType):
     pass

     @staticmethod
-    def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
+    def ones(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
         """
         Creates a new tensor filled with ones.
         """
         pass
     @staticmethod
-    def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
+    def rand(*shape: Shape, device: Optional[Device] = None) -> Tensor:
         """
         Creates a new tensor with random values.
         """
         pass
     @staticmethod
-    def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
+    def randn(*shape: Shape, device: Optional[Device] = None) -> Tensor:
         """
         Creates a new tensor with random values from a normal distribution.
         """
@@ -67,7 +67,7 @@ class u8(DType):
     pass

     @staticmethod
-    def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
+    def zeros(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
         """
         Creates a new tensor filled with zeros.
         """
@@ -124,16 +124,46 @@ class Tensor:
         Add a scalar to a tensor or two tensors together.
         """
         pass
+    def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+    def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
     def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
         """
         Return a slice of a tensor.
         """
         pass
+    def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+    def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+    def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
     def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Multiply a tensor by a scalar or one tensor by another.
         """
         pass
+    def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
     def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Add a scalar to a tensor or two tensors together.
@@ -159,6 +189,11 @@ class Tensor:
         Divide a tensor by a scalar or one tensor by another.
         """
         pass
+    def abs(self) -> Tensor:
+        """
+        Performs the `abs` operation on the tensor.
+        """
+        pass
     def argmax_keepdim(self, dim: int) -> Tensor:
         """
         Returns the indices of the maximum value(s) across the selected dimension.
@@ -174,7 +209,7 @@ class Tensor:
         Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
         """
         pass
-    def broadcast_as(self, shape: Sequence[int]) -> Tensor:
+    def broadcast_as(self, *shape: Shape) -> Tensor:
         """
         Broadcasts the tensor to the given shape.
         """
@@ -184,7 +219,7 @@ class Tensor:
         Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
         """
         pass
-    def broadcast_left(self, shape: Sequence[int]) -> Tensor:
+    def broadcast_left(self, *shape: Shape) -> Tensor:
         """
         Broadcasts the tensor to the given shape, adding new dimensions on the left.
         """
@@ -308,6 +343,12 @@ class Tensor:
         ranges from `start` to `start + len`.
         """
         pass
+    @property
+    def nelement(self) -> int:
+        """
+        Gets the tensor's element count.
+        """
+        pass
     def powf(self, p: float) -> Tensor:
         """
         Performs the `pow` operation on the tensor with the given exponent.
@@ -329,7 +370,7 @@ class Tensor:
         Get the `recip` of the tensor.
         """
         pass
-    def reshape(self, shape: Sequence[int]) -> Tensor:
+    def reshape(self, *shape: Shape) -> Tensor:
         """
         Reshapes the tensor to the given shape.
         """
@@ -396,6 +437,11 @@ class Tensor:
         Convert the tensor to a new dtype.
         """
         pass
+    def to_torch(self) -> torch.Tensor:
+        """
+        Converts candle's tensor to pytorch's tensor
+        """
+        pass
     def transpose(self, dim1: int, dim2: int) -> Tensor:
         """
         Returns a tensor that is a transposed version of the input, the given dimensions are swapped.

@@ -1,7 +1,7 @@
 # Generated content DO NOT EDIT
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
 from os import PathLike
-from candle.typing import _ArrayLike, Device, Scalar, Index
+from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
 from candle import Tensor, DType, QTensor

 @staticmethod

candle-pyo3/py_src/candle/testing/__init__.py (new file, +70)
@@ -0,0 +1,70 @@
import candle
from candle import Tensor


_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])


def _assert_tensor_metadata(
    actual: Tensor,
    expected: Tensor,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    if check_device:
        assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"

    if check_dtype:
        assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"

    if check_layout:
        assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"

    if check_stride:
        assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"


def assert_equal(
    actual: Tensor,
    expected: Tensor,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    """
    Asserts that two tensors are exact equals.
    """
    _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
    assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"


def assert_almost_equal(
    actual: Tensor,
    expected: Tensor,
    rtol=1e-05,
    atol=1e-08,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    """
    Asserts that two tensors are almost equal by performing an element-wise comparison with a tolerance.

    Computes: |actual - expected| ≤ atol + rtol x |expected|
    """
    _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)

    # Secure against overflow of u32 and u8 tensors
    if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
        actual = actual.to(candle.i64)
        expected = expected.to(candle.i64)

    diff = (actual - expected).abs()

    threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)

    assert (diff <= threshold).sum_all().values() == actual.nelement, "Difference between tensors was too great"

@ -18,3 +18,5 @@ Device = TypeVar("Device", CPU, CUDA)
|
||||
Scalar = Union[int, float]
|
||||
|
||||
Index = Union[int, slice, None, "Ellipsis"]
|
||||
|
||||
Shape = Union[int, Sequence[int]]
|
||||
|
@@ -1,7 +1,7 @@
 # Generated content DO NOT EDIT
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
 from os import PathLike
-from candle.typing import _ArrayLike, Device, Scalar, Index
+from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
 from candle import Tensor, DType, QTensor

 @staticmethod

@@ -1,8 +1,11 @@
 #![allow(clippy::redundant_closure_call)]
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
+use pyo3::pyclass::CompareOp;
 use pyo3::types::{IntoPyDict, PyDict, PyTuple};
 use pyo3::ToPyObject;
+use std::collections::hash_map::DefaultHasher;
+use std::hash::{Hash, Hasher};
 use std::os::raw::c_long;
 use std::sync::Arc;

@@ -16,26 +19,13 @@ extern crate accelerate_src;

 use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};

+mod shape;
+use shape::{PyShape, PyShapeWithHole};
+
 pub fn wrap_err(err: ::candle::Error) -> PyErr {
     PyErr::new::<PyValueError, _>(format!("{err:?}"))
 }

-#[derive(Clone, Debug)]
-struct PyShape(Vec<usize>);
-
-impl<'source> pyo3::FromPyObject<'source> for PyShape {
-    fn extract(ob: &'source PyAny) -> PyResult<Self> {
-        let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
-        Ok(PyShape(dims))
-    }
-}
-
-impl From<PyShape> for ::candle::Shape {
-    fn from(val: PyShape) -> Self {
-        val.0.into()
-    }
-}
-
 #[derive(Clone, Debug)]
 #[pyclass(name = "Tensor")]
 /// A `candle` tensor.
@@ -145,9 +135,10 @@ macro_rules! pydtype {
         }
     };
 }

-pydtype!(i64, |v| v);
 pydtype!(u8, |v| v);
 pydtype!(u32, |v| v);
+pydtype!(i64, |v| v);
 pydtype!(f16, f32::from);
 pydtype!(bf16, f32::from);
 pydtype!(f32, |v| v);
@@ -211,6 +202,16 @@ enum Indexer {
     IndexSelect(Tensor),
 }

+#[derive(Clone, Debug)]
+struct TorchTensor(PyObject);
+
+impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
+        Ok(TorchTensor(numpy_value))
+    }
+}
+
 #[pymethods]
 impl PyTensor {
     #[new]
@@ -246,6 +247,8 @@ impl PyTensor {
             Tensor::new(vs, &Cpu).map_err(wrap_err)?
         } else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
             Tensor::new(vs, &Cpu).map_err(wrap_err)?
+        } else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
+            return PyTensor::new(py, numpy);
         } else {
             let ty = data.as_ref(py).get_type();
             Err(PyTypeError::new_err(format!(
@@ -299,6 +302,18 @@ impl PyTensor {
         M(py).map(self)
     }

+    /// Converts candle's tensor to pytorch's tensor
+    /// &RETURNS&: torch.Tensor
+    fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
+        let candle_values = self.values(py)?;
+        let torch_tensor: PyObject = py
+            .import("torch")?
+            .getattr("tensor")?
+            .call1((candle_values,))?
+            .extract()?;
+        Ok(torch_tensor)
+    }
+
     #[getter]
     /// Gets the tensor's shape.
     /// &RETURNS&: Tuple[int]
@@ -306,6 +321,13 @@ impl PyTensor {
         PyTuple::new(py, self.0.dims()).to_object(py)
     }

+    #[getter]
+    /// Gets the tensor's element count.
+    /// &RETURNS&: int
+    fn nelement(&self) -> usize {
+        self.0.elem_count()
+    }
+
     #[getter]
     /// Gets the tensor's strides.
     /// &RETURNS&: Tuple[int]
@@ -342,6 +364,12 @@ impl PyTensor {
         self.__repr__()
     }

+    /// Performs the `abs` operation on the tensor.
+    /// &RETURNS&: Tensor
+    fn abs(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
+    }
+
     /// Performs the `sin` operation on the tensor.
     /// &RETURNS&: Tensor
     fn sin(&self) -> PyResult<Self> {
@@ -659,26 +687,90 @@ impl PyTensor {
         };
         Ok(Self(tensor))
     }
+
+    /// Rich-compare two tensors.
+    /// &RETURNS&: Tensor
+    fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
+        let compare = |lhs: &Tensor, rhs: &Tensor| {
+            let t = match op {
+                CompareOp::Eq => lhs.eq(rhs),
+                CompareOp::Ne => lhs.ne(rhs),
+                CompareOp::Lt => lhs.lt(rhs),
+                CompareOp::Le => lhs.le(rhs),
+                CompareOp::Gt => lhs.gt(rhs),
+                CompareOp::Ge => lhs.ge(rhs),
+            };
+            Ok(PyTensor(t.map_err(wrap_err)?))
+        };
+        if let Ok(rhs) = rhs.extract::<PyTensor>() {
+            if self.0.shape() == rhs.0.shape() {
+                compare(&self.0, &rhs.0)
+            } else {
+                // We broadcast manually here because `candle.cmp` does not support automatic broadcasting
+                let broadcast_shape = self
+                    .0
+                    .shape()
+                    .broadcast_shape_binary_op(rhs.0.shape(), "cmp")
+                    .map_err(wrap_err)?;
+                let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
+                let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
+                compare(&broadcasted_lhs, &broadcasted_rhs)
+            }
+        } else if let Ok(rhs) = rhs.extract::<f64>() {
+            let scalar_tensor = Tensor::new(rhs, self.0.device())
+                .map_err(wrap_err)?
+                .to_dtype(self.0.dtype())
+                .map_err(wrap_err)?
+                .broadcast_as(self.0.shape())
+                .map_err(wrap_err)?;
+            compare(&self.0, &scalar_tensor)
+        } else {
+            return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
+        }
+    }
+
+    fn __hash__(&self) -> u64 {
+        // we have overridden __richcmp__ => py03 wants us to also override __hash__
+        // we simply hash the address of the tensor
+        let mut hasher = DefaultHasher::new();
+        let pointer = &self.0 as *const Tensor;
+        let address = pointer as usize;
+        address.hash(&mut hasher);
+        hasher.finish()
+    }

-    #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
     /// Reshapes the tensor to the given shape.
     /// &RETURNS&: Tensor
-    fn reshape(&self, shape: PyShape) -> PyResult<Self> {
-        Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
+    fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> {
+        Ok(PyTensor(
+            self.0
+                .reshape(shape.to_absolute(&self.0)?)
+                .map_err(wrap_err)?,
+        ))
     }

-    #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
     /// Broadcasts the tensor to the given shape.
     /// &RETURNS&: Tensor
-    fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
-        Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
+    fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> {
+        Ok(PyTensor(
+            self.0
+                .broadcast_as(shape.to_absolute(&self.0)?)
+                .map_err(wrap_err)?,
+        ))
     }

-    #[pyo3(text_signature = "(self, shape:Sequence[int])")]
+    #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
     /// Broadcasts the tensor to the given shape, adding new dimensions on the left.
     /// &RETURNS&: Tensor
-    fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
-        Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
+    fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
+        Ok(PyTensor(
+            self.0
+                .broadcast_left(shape.to_absolute(&self.0)?)
+                .map_err(wrap_err)?,
+        ))
     }

     #[pyo3(text_signature = "(self, dim:int)")]
@@ -891,21 +983,21 @@ impl PyTensor {
         }

         if let Some(kwargs) = kwargs {
-            if let Some(any) = kwargs.get_item("dtype") {
+            if let Ok(Some(any)) = kwargs.get_item("dtype") {
                 handle_duplicates(
                     &mut dtype,
                     any.extract::<PyDType>(),
                     "cannot specify multiple dtypes",
                 )?;
             }
-            if let Some(any) = kwargs.get_item("device") {
+            if let Ok(Some(any)) = kwargs.get_item("device") {
                 handle_duplicates(
                     &mut device,
                     any.extract::<PyDevice>(),
                     "cannot specify multiple devices",
                 )?;
             }
-            if let Some(any) = kwargs.get_item("other") {
+            if let Ok(Some(any)) = kwargs.get_item("other") {
                 handle_duplicates(
                     &mut other,
                     any.extract::<PyTensor>(),
@@ -1025,27 +1117,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
 }

 #[pyfunction]
-#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
 /// Creates a new tensor with random values.
 /// &RETURNS&: Tensor
 fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
     let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
-    let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+    let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }

 #[pyfunction]
-#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
 /// Creates a new tensor with random values from a normal distribution.
 /// &RETURNS&: Tensor
 fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
     let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
-    let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
+    let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }

 #[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
 /// Creates a new tensor filled with ones.
 /// &RETURNS&: Tensor
 fn ones(
@@ -1059,12 +1151,12 @@ fn ones(
         Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
     };
     let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
-    let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
+    let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }

 #[pyfunction]
-#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
+#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
 /// Creates a new tensor filled with zeros.
 /// &RETURNS&: Tensor
 fn zeros(
@@ -1078,7 +1170,7 @@ fn zeros(
         Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
     };
     let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
-    let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
+    let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
     Ok(PyTensor(tensor))
 }

@@ -1480,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyDType>()?;
     m.add("u8", PyDType(DType::U8))?;
     m.add("u32", PyDType(DType::U32))?;
+    m.add("i16", PyDType(DType::I64))?;
     m.add("i64", PyDType(DType::I64))?;
     m.add("bf16", PyDType(DType::BF16))?;
     m.add("f16", PyDType(DType::F16))?;
     m.add("f32", PyDType(DType::F32))?;

candle-pyo3/src/shape.rs (new file, +99)
@@ -0,0 +1,99 @@
use ::candle::Tensor;
use pyo3::prelude::*;

#[derive(Clone, Debug)]
/// Represents an absolute shape e.g. (1, 2, 3)
pub struct PyShape(Vec<usize>);

impl<'source> pyo3::FromPyObject<'source> for PyShape {
    fn extract(ob: &'source PyAny) -> PyResult<Self> {
        if ob.is_none() {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Shape cannot be None",
            ));
        }

        let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
        if tuple.len() == 1 {
            let first_element = tuple.get_item(0)?;
            let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
            Ok(PyShape(dims))
        } else {
            let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
            Ok(PyShape(dims))
        }
    }
}

impl From<PyShape> for ::candle::Shape {
    fn from(val: PyShape) -> Self {
        val.0.into()
    }
}

#[derive(Clone, Debug)]
/// Represents a shape with a hole in it e.g. (1, -1, 3)
pub struct PyShapeWithHole(Vec<isize>);

impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
    fn extract(ob: &'source PyAny) -> PyResult<Self> {
        if ob.is_none() {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Shape cannot be None",
            ));
        }

        let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
        let dims: Vec<isize> = if tuple.len() == 1 {
            let first_element = tuple.get_item(0)?;
            pyo3::FromPyObject::extract(first_element)?
        } else {
            pyo3::FromPyObject::extract(tuple)?
        };

        // Ensure we have only positive numbers and at most one "hole" (-1)
        let negative_ones = dims.iter().filter(|&&x| x == -1).count();
        let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0);
        if negative_ones > 1 || any_invalid_dimensions {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                "Invalid dimension in shape: {:?}",
                dims
            )));
        }

        Ok(PyShapeWithHole(dims))
    }
}

impl PyShapeWithHole {
    /// Returns `true` if the shape is absolute e.g. (1, 2, 3)
    pub fn is_absolute(&self) -> bool {
        self.0.iter().all(|x| *x > 0)
    }

    /// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12)
    pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> {
        if self.is_absolute() {
            return Ok(PyShape(
                self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(),
            ));
        }

        let mut elements = t.elem_count();
        let mut new_dims: Vec<usize> = vec![];
        for dim in self.0.iter() {
            if *dim > 0 {
                new_dims.push(*dim as usize);
                elements /= *dim as usize;
            } else if *dim == -1 {
                new_dims.push(elements);
            } else {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                    "Invalid dimension in shape: {}",
                    dim
                )));
            }
        }
        Ok(PyShape(new_dims))
    }
}
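A quick illustration of the hole-filling rule in `to_absolute`: explicit dimensions are consumed left to right, each dividing the remaining element count, and `-1` then receives whatever remains at that point. Using the public Python API:

```python
import candle

t = candle.rand(2, 3, 4)       # 24 elements in total
print(t.reshape(6, -1).shape)  # (6, 4): the hole receives 24 / 6 = 4
print(t.reshape(-1).shape)     # (24,)
```
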
@ -13,7 +13,7 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
|
||||
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
"""
|
||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
|
||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index, Shape\n"
|
||||
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
|
||||
RETURN_TYPE_MARKER = "&RETURNS&: "
|
||||
ADDITIONAL_TYPEHINTS = {}
|
||||
|
candle-pyo3/test_pytorch.py (new file, +14)
@@ -0,0 +1,14 @@
import candle
import torch

# convert from candle tensor to torch tensor
t = candle.randn((3, 512, 512))
torch_tensor = t.to_torch()
print(torch_tensor)
print(type(torch_tensor))

# convert from torch tensor to candle tensor
t = torch.randn((3, 512, 512))
candle_tensor = candle.Tensor(t)
print(candle_tensor)
print(type(candle_tensor))

candle-pyo3/tests/bindings/test_testing.py (new file, +33)
@@ -0,0 +1,33 @@
import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal
import pytest


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_equal_asserts_correctly(dtype: candle.DType):
    a = Tensor([1, 2, 3]).to(dtype)
    b = Tensor([1, 2, 3]).to(dtype)
    assert_equal(a, b)

    with pytest.raises(AssertionError):
        assert_equal(a, b + 1)


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
    a = Tensor([1, 2, 3]).to(dtype)
    b = Tensor([1, 2, 3]).to(dtype)
    assert_almost_equal(a, b)

    with pytest.raises(AssertionError):
        assert_almost_equal(a, b + 1)

    assert_almost_equal(a, b + 1, atol=20)
    assert_almost_equal(a, b + 1, rtol=20)

    with pytest.raises(AssertionError):
        assert_almost_equal(a, b + 1, atol=0.9)

    with pytest.raises(AssertionError):
        assert_almost_equal(a, b + 1, rtol=0.1)

candle-pyo3/tests/native/test_shape.py (new file, +31)
@@ -0,0 +1,31 @@
from candle import Tensor
from candle import rand
import pytest


def test_absolute_shapes_are_valid():
    a = rand((10, 20))
    assert a.shape == (10, 20)

    b = rand(10, 20)
    assert b.shape == (10, 20)
    pytest.raises(OverflowError, lambda: rand((10, 20, -1)))
    pytest.raises(OverflowError, lambda: rand(-1, 20))
    pytest.raises(TypeError, lambda: rand("foo", True))


def test_relative_shapes_are_valid():
    a = rand(10, 20)
    a = a.reshape((1, -1))
    assert a.shape == (1, 200)

    b = rand(10, 20)
    b = b.reshape(-1, 1)
    assert b.shape == (200, 1)

    c = rand(10, 20)
    pytest.raises(TypeError, lambda: c.reshape(1, "foo"))
    pytest.raises(ValueError, lambda: c.reshape(1, -2))
    pytest.raises(ValueError, lambda: c.reshape((-2, 1)))
    pytest.raises(ValueError, lambda: c.reshape((0, 1)))
    pytest.raises(ValueError, lambda: c.reshape((1, -1, -1)))

@@ -1,6 +1,7 @@
 import candle
 from candle import Tensor
 from candle.utils import cuda_is_available
+from candle.testing import assert_equal
 import pytest

@@ -77,6 +78,78 @@ def test_tensor_can_be_scliced_3d():
     assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]


+def assert_bool(t: Tensor, expected: bool):
+    assert t.shape == ()
+    assert str(t.dtype) == str(candle.u8)
+    assert bool(t.values()) == expected
+
+
+def test_tensor_supports_equality_opperations_with_scalars():
+    t = Tensor(42.0)
+
+    assert_bool(t == 42.0, True)
+    assert_bool(t == 43.0, False)
+
+    assert_bool(t != 42.0, False)
+    assert_bool(t != 43.0, True)
+
+    assert_bool(t > 41.0, True)
+    assert_bool(t > 42.0, False)
+
+    assert_bool(t >= 41.0, True)
+    assert_bool(t >= 42.0, True)
+
+    assert_bool(t < 43.0, True)
+    assert_bool(t < 42.0, False)
+
+    assert_bool(t <= 43.0, True)
+    assert_bool(t <= 42.0, True)
+
+
+def test_tensor_supports_equality_opperations_with_tensors():
+    t = Tensor(42.0)
+    same = Tensor(42.0)
+    other = Tensor(43.0)
+
+    assert_bool(t == same, True)
+    assert_bool(t == other, False)
+
+    assert_bool(t != same, False)
+    assert_bool(t != other, True)
+
+    assert_bool(t > same, False)
+    assert_bool(t > other, False)
+
+    assert_bool(t >= same, True)
+    assert_bool(t >= other, False)
+
+    assert_bool(t < same, False)
+    assert_bool(t < other, True)
+
+    assert_bool(t <= same, True)
+    assert_bool(t <= other, True)
+
+
+def test_tensor_equality_opperations_can_broadcast():
+    # Create a decoder attention mask as a test case
+    # e.g.
+    # [[1,0,0]
+    #  [1,1,0]
+    #  [1,1,1]]
+    mask_cond = candle.Tensor([0, 1, 2])
+    mask = mask_cond < (mask_cond + 1).reshape((3, 1))
+    assert mask.shape == (3, 3)
+    assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))
+
+
+def test_tensor_can_be_hashed():
+    t = Tensor(42.0)
+    other = Tensor(42.0)
+    # Hash should represent a unique tensor
+    assert hash(t) != hash(other)
+    assert hash(t) == hash(t)
+
+
 def test_tensor_can_be_expanded_with_none():
     t = candle.rand((12, 12))

@ -11,6 +11,7 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
|
@@ -1,3 +1,4 @@
use super::with_tracing::{linear, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
@@ -32,33 +33,6 @@ impl HiddenActLayer {
    }
}

#[derive(Debug)]
pub struct Linear {
    weight: Tensor,
    bias: Option<Tensor>,
    span: tracing::Span,
}

impl Linear {
    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "linear");
        Self { weight, bias, span }
    }

    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        let _enter = self.span.enter();
        let w = match x.dims() {
            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
            _ => self.weight.t()?,
        };
        let x = x.matmul(&w)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

#[derive(Debug)]
pub struct LayerNorm {
    weight: Tensor,
@@ -77,8 +51,10 @@ impl LayerNorm {
            span,
        }
    }
}

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
impl Module for LayerNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
@@ -180,12 +156,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
    Ok(Embedding::new(embeddings, hidden_size))
}

fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    let bias = vb.get(size2, "bias")?;
    Ok(Linear::new(weight, Some(bias)))
}

struct Dropout {
    #[allow(dead_code)]
    pr: f64,
@@ -195,7 +165,9 @@ impl Dropout {
    fn new(pr: f64) -> Self {
        Self { pr }
    }
}

impl Module for Dropout {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // TODO
        Ok(x.clone())
@@ -316,7 +288,9 @@ impl BertSelfAttention {
        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
        xs.contiguous()
    }
}

impl Module for BertSelfAttention {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let query_layer = self.query.forward(hidden_states)?;
@@ -391,7 +365,9 @@ impl BertAttention {
            span: tracing::span!(tracing::Level::TRACE, "attn"),
        })
    }
}

impl Module for BertAttention {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let self_outputs = self.self_attention.forward(hidden_states)?;
@@ -416,7 +392,9 @@ impl BertIntermediate {
            span: tracing::span!(tracing::Level::TRACE, "inter"),
        })
    }
}

impl Module for BertIntermediate {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
@@ -478,7 +456,9 @@ impl BertLayer {
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }
}

impl Module for BertLayer {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let attention_output = self.attention.forward(hidden_states)?;
@@ -507,7 +487,9 @@ impl BertEncoder {
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(BertEncoder { layers, span })
    }
}

impl Module for BertEncoder {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden_states = hidden_states.clone();
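The bert.rs hunks above drop the file-local `Linear` wrapper in favor of the shared `with_tracing` version. The underlying pattern is small enough to sketch in isolation; a hedged sketch assuming the `tracing` crate and a stand-in `Layer` trait (not a candle API):

```rust
// Sketch of the tracing-wrapper pattern that with_tracing::Linear follows:
// wrap any layer and enter a span around each forward call.
pub trait Layer {
    fn forward(&self, x: f32) -> f32;
}

pub struct Traced<L> {
    inner: L,
    span: tracing::Span,
}

impl<L> Traced<L> {
    pub fn new(inner: L) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "layer");
        Self { inner, span }
    }
}

impl<L: Layer> Layer for Traced<L> {
    fn forward(&self, x: f32) -> f32 {
        // Entering the span brackets the call, so a profiler can
        // attribute time to this layer.
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}
```

The only job of the wrapper is that `span.enter()` brackets each forward call; centralizing it in `with_tracing` removes the per-model copies deleted above.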
@@ -2,8 +2,9 @@ use super::blip_text;
use super::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, Result, Tensor, D};
use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
use serde::Deserialize;

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Deserialize)]
pub struct VisionConfig {
    pub hidden_size: usize,
    pub intermediate_size: usize,
@@ -16,7 +17,7 @@ pub struct VisionConfig {
    pub layer_norm_eps: f64,
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    pub text_config: blip_text::Config,
    pub vision_config: VisionConfig,
@@ -299,4 +300,8 @@ impl BlipForConditionalGeneration {
    pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
        &mut self.text_decoder
    }

    pub fn reset_kv_cache(&mut self) {
        self.text_decoder.reset_kv_cache();
    }
}
@@ -1,8 +1,9 @@
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Module, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
use serde::Deserialize;

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
candle-transformers/src/models/jina_bert.rs (new file, 369 lines)
@@ -0,0 +1,369 @@
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
use serde::Deserialize;

pub const DTYPE: DType = DType::F32;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
    Absolute,
    Alibi,
}

// https://huggingface.co/jinaai/jina-bert-implementation/blob/main/configuration_bert.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub intermediate_size: usize,
    pub hidden_act: candle_nn::Activation,
    pub max_position_embeddings: usize,
    pub type_vocab_size: usize,
    pub initializer_range: f64,
    pub layer_norm_eps: f64,
    pub pad_token_id: usize,
    pub position_embedding_type: PositionEmbeddingType,
}

impl Config {
    pub fn v2_base() -> Self {
        // https://huggingface.co/jinaai/jina-embeddings-v2-base-en/blob/main/config.json
        Self {
            vocab_size: 30528,
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: candle_nn::Activation::Gelu,
            max_position_embeddings: 8192,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
            position_embedding_type: PositionEmbeddingType::Alibi,
        }
    }
}

#[derive(Clone, Debug)]
struct BertEmbeddings {
    word_embeddings: Embedding,
    // no position_embeddings as we only support alibi.
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertEmbeddings {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let word_embeddings =
            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
        let token_type_embeddings = Embedding::new(
            cfg.type_vocab_size,
            cfg.hidden_size,
            vb.pp("token_type_embeddings"),
        )?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self {
            word_embeddings,
            token_type_embeddings,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }
}

impl Module for BertEmbeddings {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b_size, seq_len) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let token_type_embeddings = Tensor::zeros(seq_len, DType::U32, input_ids.device())?
            .broadcast_left(b_size)?
            .apply(&self.token_type_embeddings)?;
        let embeddings = (&input_embeddings + token_type_embeddings)?;
        let embeddings = self.layer_norm.forward(&embeddings)?;
        Ok(embeddings)
    }
}

#[derive(Clone, Debug)]
struct BertSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    num_attention_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
    span_softmax: tracing::Span,
}

impl BertSelfAttention {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
        let all_head_size = cfg.num_attention_heads * attention_head_size;
        let hidden_size = cfg.hidden_size;
        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
        Ok(Self {
            query,
            key,
            value,
            num_attention_heads: cfg.num_attention_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
        })
    }

    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let mut x_shape = xs.dims().to_vec();
        x_shape.pop();
        x_shape.push(self.num_attention_heads);
        x_shape.push(self.attention_head_size);
        xs.reshape(x_shape)?.transpose(1, 2)?.contiguous()
    }

    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let query_layer = self.query.forward(xs)?;
        let key_layer = self.key.forward(xs)?;
        let value_layer = self.value.forward(xs)?;

        let query_layer = self.transpose_for_scores(&query_layer)?;
        let key_layer = self.transpose_for_scores(&key_layer)?;
        let value_layer = self.transpose_for_scores(&value_layer)?;

        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
        let attention_scores = attention_scores.broadcast_add(bias)?;
        let attention_probs = {
            let _enter_sm = self.span_softmax.enter();
            candle_nn::ops::softmax_last_dim(&attention_scores)?
        };
        let context_layer = attention_probs.matmul(&value_layer)?;
        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
        let context_layer = context_layer.flatten_from(D::Minus2)?;
        Ok(context_layer)
    }
}

#[derive(Clone, Debug)]
struct BertSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertSelfOutput {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "self-out"),
        })
    }

    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let xs = self.dense.forward(xs)?;
        self.layer_norm.forward(&(xs + input_tensor)?)
    }
}

#[derive(Clone, Debug)]
struct BertAttention {
    self_attention: BertSelfAttention,
    self_output: BertSelfOutput,
    span: tracing::Span,
}

impl BertAttention {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let self_attention = BertSelfAttention::new(vb.pp("self"), cfg)?;
        let self_output = BertSelfOutput::new(vb.pp("output"), cfg)?;
        Ok(Self {
            self_attention,
            self_output,
            span: tracing::span!(tracing::Level::TRACE, "attn"),
        })
    }

    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let self_outputs = self.self_attention.forward(xs, bias)?;
        let attention_output = self.self_output.forward(&self_outputs, xs)?;
        Ok(attention_output)
    }
}

#[derive(Clone, Debug)]
struct BertGLUMLP {
    gated_layers: Linear,
    act: candle_nn::Activation,
    wo: Linear,
    layernorm: LayerNorm,
    intermediate_size: usize,
}

impl BertGLUMLP {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let gated_layers = linear_no_bias(
            cfg.hidden_size,
            cfg.intermediate_size * 2,
            vb.pp("gated_layers"),
        )?;
        let act = candle_nn::Activation::Gelu; // geglu
        let wo = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("wo"))?;
        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
        Ok(Self {
            gated_layers,
            act,
            wo,
            layernorm,
            intermediate_size: cfg.intermediate_size,
        })
    }
}

impl Module for BertGLUMLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.gated_layers)?;
        let gated = xs.narrow(D::Minus1, 0, self.intermediate_size)?;
        let non_gated = xs.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
        let xs = (gated.apply(&self.act) * non_gated)?.apply(&self.wo);
        (xs + residual)?.apply(&self.layernorm)
    }
}

#[derive(Clone, Debug)]
struct BertLayer {
    attention: BertAttention,
    mlp: BertGLUMLP,
    span: tracing::Span,
}

impl BertLayer {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let attention = BertAttention::new(vb.pp("attention"), cfg)?;
        let mlp = BertGLUMLP::new(vb.pp("mlp"), cfg)?;
        Ok(Self {
            attention,
            mlp,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }

    fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.attention.forward(xs, bias)?.apply(&self.mlp)
    }
}

fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
    let n_heads = cfg.num_attention_heads;
    let seq_len = cfg.max_position_embeddings;
    let alibi_bias = Tensor::arange(0, seq_len as i64, &Device::Cpu)?.to_dtype(DType::F32)?;
    let alibi_bias = {
        let a1 = alibi_bias.reshape((1, seq_len))?;
        let a2 = alibi_bias.reshape((seq_len, 1))?;
        a1.broadcast_sub(&a2)?.abs()?.broadcast_left(n_heads)?
    };
    let mut n_heads2 = 1;
    while n_heads2 < n_heads {
        n_heads2 *= 2
    }
    let slopes = (1..=n_heads2)
        .map(|v| -1f32 / 2f32.powf((v * 8) as f32 / n_heads2 as f32))
        .collect::<Vec<_>>();
    let slopes = if n_heads2 == n_heads {
        slopes
    } else {
        slopes
            .iter()
            .skip(1)
            .step_by(2)
            .chain(slopes.iter().step_by(2))
            .take(n_heads)
            .cloned()
            .collect::<Vec<f32>>()
    };
    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
}

#[derive(Clone, Debug)]
struct BertEncoder {
    alibi: Tensor,
    layers: Vec<BertLayer>,
    span: tracing::Span,
}

impl BertEncoder {
    fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        if cfg.position_embedding_type != PositionEmbeddingType::Alibi {
            candle::bail!("only alibi is supported as a position-embedding-type")
        }
        let layers = (0..cfg.num_hidden_layers)
            .map(|index| BertLayer::new(vb.pp(&format!("layer.{index}")), cfg))
            .collect::<Result<Vec<_>>>()?;
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?;
        Ok(Self {
            alibi,
            layers,
            span,
        })
    }
}

impl Module for BertEncoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let seq_len = xs.dim(1)?;
        let alibi_bias = self.alibi.i((.., .., ..seq_len, ..seq_len))?;
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, &alibi_bias)?
        }
        Ok(xs)
    }
}

#[derive(Clone, Debug)]
pub struct BertModel {
    embeddings: BertEmbeddings,
    encoder: BertEncoder,
    pub device: Device,
    span: tracing::Span,
}

impl BertModel {
    pub fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let embeddings = BertEmbeddings::new(vb.pp("embeddings"), cfg)?;
        let encoder = BertEncoder::new(vb.pp("encoder"), cfg)?;
        Ok(Self {
            embeddings,
            encoder,
            device: vb.device().clone(),
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }
}

impl Module for BertModel {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids)?;
        let sequence_output = self.encoder.forward(&embedding_output)?;
        Ok(sequence_output)
    }
}
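`build_alibi_bias` above multiplies a |i − j| distance matrix by per-head slopes; for non-power-of-two head counts it interleaves slopes from the next power of two. A plain-Rust sketch of just the power-of-two slope schedule, using the same formula as the code above:

```rust
// Minimal sketch of the ALiBi slope schedule used above, power-of-two case:
// slope_v = -1 / 2^(8v / n_heads) for v = 1..=n_heads.
fn alibi_slopes(n_heads: usize) -> Vec<f32> {
    assert!(n_heads.is_power_of_two(), "sketch covers the power-of-two case only");
    (1..=n_heads)
        .map(|v| -1f32 / 2f32.powf((v * 8) as f32 / n_heads as f32))
        .collect()
}

fn main() {
    // With 4 heads the slopes shrink geometrically: -1/4, -1/16, -1/64, -1/256.
    let s = alibi_slopes(4);
    assert!((s[0] + 0.25).abs() < 1e-6);
    assert!((s[3] + 1.0 / 256.0).abs() < 1e-6);
}
```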
@@ -1,3 +1,4 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
@@ -81,21 +82,6 @@ impl Config {
    }
}

// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

#[derive(Clone)]
pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@@ -150,12 +136,6 @@ impl Cache {
    }
}

fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
    Ok(Linear { inner, span })
}

fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
    let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
    Ok(Embedding::new(embeddings, cfg.hidden_size))
@@ -17,7 +17,20 @@ pub struct Config {
}

impl Config {
    pub fn tiny() -> Self {
    pub fn tiny_260k() -> Self {
        Self {
            dim: 64,
            hidden_dim: 768,
            n_layers: 5,
            n_heads: 8,
            n_kv_heads: 4,
            vocab_size: 32000,
            seq_len: 512,
            norm_eps: 1e-5,
        }
    }

    pub fn tiny_15m() -> Self {
        Self {
            dim: 288,
            hidden_dim: 768,
@@ -29,6 +42,32 @@ impl Config {
            norm_eps: 1e-5,
        }
    }

    pub fn tiny_42m() -> Self {
        Self {
            dim: 512,
            hidden_dim: 768,
            n_layers: 8,
            n_heads: 8,
            n_kv_heads: 8,
            vocab_size: 32000,
            seq_len: 1024,
            norm_eps: 1e-5,
        }
    }

    pub fn tiny_110m() -> Self {
        Self {
            dim: 768,
            hidden_dim: 768,
            n_layers: 12,
            n_heads: 12,
            n_kv_heads: 12,
            vocab_size: 32000,
            seq_len: 1024,
            norm_eps: 1e-5,
        }
    }
}

#[derive(Clone)]
@@ -36,9 +75,9 @@ pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
    pub use_kv_cache: bool,
    #[allow(clippy::type_complexity)]
    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    cos: Tensor,
    sin: Tensor,
    pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    pub cos: Tensor,
    pub sin: Tensor,
    device: Device,
}

@@ -75,7 +114,7 @@ impl Cache {
        })
    }

    fn mask(&self, t: usize) -> Result<Tensor> {
    pub fn mask(&self, t: usize) -> Result<Tensor> {
        let mut masks = self.masks.lock().unwrap();
        if let Some(mask) = masks.get(&t) {
            Ok(mask.clone())
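The new `tiny_260k`/`tiny_42m`/`tiny_110m` presets differ only in width and depth. As a rough sanity check on the naming, a back-of-the-envelope parameter count can be sketched; everything flagged below is an assumption, not something the diff states (in particular the layer count for the 15M preset, which the hunk cuts off):

```rust
// Back-of-the-envelope parameter count for llama2.c-style configs.
// Assumptions (not from the diff): SwiGLU MLP with gate/up/down projections,
// no biases, norms ignored, embedding table shared with the lm head,
// and 6 layers for the 15M preset (matching the upstream stories15M checkpoint).
fn approx_params(dim: usize, hidden_dim: usize, n_layers: usize, vocab_size: usize) -> usize {
    let attn = 4 * dim * dim; // q, k, v, o projections (kv grouping ignored)
    let mlp = 3 * dim * hidden_dim; // gate, up, down
    n_layers * (attn + mlp) + vocab_size * dim // + shared embedding table
}

fn main() {
    // tiny_15m: dim 288, hidden_dim 768 -> roughly 15.2M parameters.
    println!("{}", approx_params(288, 768, 6, 32000));
}
```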
@@ -1,9 +1,8 @@
use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Shape, Tensor};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle_nn::VarBuilder;

use crate::model::Config;
use super::llama2_c::Config;

pub struct TransformerWeights {
    // token embedding table
candle-transformers/src/models/marian.rs (new file, 520 lines)
@@ -0,0 +1,520 @@
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};

#[derive(Debug, Clone)]
pub struct Config {
    pub vocab_size: usize,
    pub decoder_vocab_size: Option<usize>,
    pub max_position_embeddings: usize,
    pub encoder_layers: usize,
    pub encoder_ffn_dim: usize,
    pub encoder_attention_heads: usize,
    pub decoder_layers: usize,
    pub decoder_ffn_dim: usize,
    pub decoder_attention_heads: usize,
    pub use_cache: bool,
    pub is_encoder_decoder: bool,
    pub activation_function: candle_nn::Activation,
    pub d_model: usize,
    pub decoder_start_token_id: u32,
    pub scale_embedding: bool,
    pub pad_token_id: u32,
    pub eos_token_id: u32,
    pub forced_eos_token_id: u32,
    pub share_encoder_decoder_embeddings: bool,
}

impl Config {
    // https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en/blob/main/config.json
    pub fn opus_mt_tc_big_fr_en() -> Self {
        Self {
            activation_function: candle_nn::Activation::Relu,
            d_model: 1024,
            decoder_attention_heads: 16,
            decoder_ffn_dim: 4096,
            decoder_layers: 6,
            decoder_start_token_id: 53016,
            decoder_vocab_size: Some(53017),
            encoder_attention_heads: 16,
            encoder_ffn_dim: 4096,
            encoder_layers: 6,
            eos_token_id: 43311,
            forced_eos_token_id: 43311,
            is_encoder_decoder: true,
            max_position_embeddings: 1024,
            pad_token_id: 53016,
            scale_embedding: true,
            share_encoder_decoder_embeddings: true,
            use_cache: true,
            vocab_size: 53017,
        }
    }

    // https://huggingface.co/Helsinki-NLP/opus-mt-fr-en/blob/main/config.json
    pub fn opus_mt_fr_en() -> Self {
        Self {
            activation_function: candle_nn::Activation::Swish,
            d_model: 512,
            decoder_attention_heads: 8,
            decoder_ffn_dim: 2048,
            decoder_layers: 6,
            decoder_start_token_id: 59513,
            decoder_vocab_size: Some(59514),
            encoder_attention_heads: 8,
            encoder_ffn_dim: 2048,
            encoder_layers: 6,
            eos_token_id: 0,
            forced_eos_token_id: 0,
            is_encoder_decoder: true,
            max_position_embeddings: 512,
            pad_token_id: 59513,
            scale_embedding: true,
            share_encoder_decoder_embeddings: true,
            use_cache: true,
            vocab_size: 59514,
        }
    }
}

#[derive(Debug, Clone)]
struct SinusoidalPositionalEmbedding {
    emb: Embedding,
}

impl SinusoidalPositionalEmbedding {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dev = vb.device();
        let dtype = vb.dtype();
        let num_positions = cfg.max_position_embeddings;
        let dim = cfg.d_model;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, num_positions as u32, dev)?
            .to_dtype(dtype)?
            .reshape((num_positions, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let sin = freqs.sin()?;
        let cos = freqs.cos()?;
        let weights = Tensor::cat(&[&sin, &cos], 1)?.contiguous()?;
        let emb = Embedding::from_weights(weights)?;
        Ok(Self { emb })
    }

    fn forward(&self, input_ids: &Tensor, past_kv_len: usize) -> Result<Tensor> {
        let seq_len = input_ids.dim(1)?;
        Tensor::arange(
            past_kv_len as u32,
            (past_kv_len + seq_len) as u32,
            input_ids.device(),
        )?
        .apply(&self.emb)
    }
}

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    scaling: f64,
    num_heads: usize,
    head_dim: usize,
    kv_cache: Option<(Tensor, Tensor)>,
    is_decoder: bool,
}

impl Attention {
    fn new(cfg: &Config, is_decoder: bool, vb: VarBuilder) -> Result<Self> {
        let num_heads = if is_decoder {
            cfg.decoder_attention_heads
        } else {
            cfg.encoder_attention_heads
        };
        let embed_dim = cfg.d_model;
        let head_dim = embed_dim / num_heads;
        let scaling = (head_dim as f64).powf(-0.5);
        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            scaling,
            num_heads,
            head_dim,
            kv_cache: None,
            is_decoder,
        })
    }

    fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
        tensor
            .reshape((bsz, (), self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        kv_states: Option<&Tensor>,
        attn_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (b_sz, tgt_len, _) = xs.dims3()?;
        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
        let (key_states, value_states) = match kv_states {
            None => {
                let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
                let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
                if self.is_decoder {
                    let kv_states = match &self.kv_cache {
                        None => (key_states, value_states),
                        Some((p_key_states, p_value_states)) => {
                            let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
                            let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
                            (key_states, value_states)
                        }
                    };
                    self.kv_cache = Some(kv_states.clone());
                    kv_states
                } else {
                    (key_states, value_states)
                }
            }
            Some(kv_states) => {
                let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
                let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
                (key_states, value_states)
            }
        };
        let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
        let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
        let key_states = key_states.reshape(proj_shape)?;
        let value_states = value_states.reshape(proj_shape)?;
        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
        let attn_weights = match attn_mask {
            None => attn_weights,
            Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
        };
        let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = attn_probs.matmul(&value_states)?;
        attn_output
            .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
            .apply(&self.out_proj)
    }

    fn reset_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug, Clone)]
struct EncoderLayer {
    self_attn: Attention,
    self_attn_layer_norm: LayerNorm,
    activation_fn: candle_nn::Activation,
    fc1: Linear,
    fc2: Linear,
    final_layer_norm: LayerNorm,
}

impl EncoderLayer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
        let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
        let fc1 = linear(cfg.d_model, cfg.encoder_ffn_dim, vb.pp("fc1"))?;
        let fc2 = linear(cfg.encoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
        let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
        Ok(Self {
            self_attn,
            self_attn_layer_norm,
            activation_fn: cfg.activation_function,
            fc1,
            fc2,
            final_layer_norm,
        })
    }

    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = (self.self_attn.forward(xs, None, None)? + residual)?
            .apply(&self.self_attn_layer_norm)?;
        let residual = &xs;
        let xs = xs
            .apply(&self.fc1)?
            .apply(&self.activation_fn)?
            .apply(&self.fc2)?;
        (xs + residual)?.apply(&self.final_layer_norm)
    }

    fn reset_kv_cache(&mut self) {
        self.self_attn.reset_kv_cache()
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    self_attn_layer_norm: LayerNorm,
    activation_fn: candle_nn::Activation,
    encoder_attn: Attention,
    encoder_attn_layer_norm: LayerNorm,
    fc1: Linear,
    fc2: Linear,
    final_layer_norm: LayerNorm,
}

impl DecoderLayer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
        let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
        let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
        let encoder_attn_layer_norm =
            layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
        let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
        let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
        let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
        Ok(Self {
            self_attn,
            self_attn_layer_norm,
            activation_fn: cfg.activation_function,
            encoder_attn,
            encoder_attn_layer_norm,
            fc1,
            fc2,
            final_layer_norm,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_xs: Option<&Tensor>,
        attn_mask: &Tensor,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = (self.self_attn.forward(xs, None, Some(attn_mask))? + residual)?
            .apply(&self.self_attn_layer_norm)?;
        let xs = match encoder_xs {
            None => xs,
            Some(encoder_xs) => {
                let residual = &xs;
                let xs = self.encoder_attn.forward(&xs, Some(encoder_xs), None)?;
                (residual + xs)?.apply(&self.encoder_attn_layer_norm)?
            }
        };
        let residual = &xs;
        let xs = xs
            .apply(&self.fc1)?
            .apply(&self.activation_fn)?
            .apply(&self.fc2)?;
        let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
        Ok(xs)
    }

    fn reset_kv_cache(&mut self) {
        self.self_attn.reset_kv_cache();
        self.encoder_attn.reset_kv_cache()
    }
}

#[derive(Debug, Clone)]
pub struct Encoder {
    embed_tokens: Embedding,
    embed_positions: SinusoidalPositionalEmbedding,
    layers: Vec<EncoderLayer>,
    embed_scale: Option<f64>,
}

impl Encoder {
    fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
        let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
        let mut layers = Vec::with_capacity(cfg.encoder_layers);
        let vb_l = vb.pp("layers");
        for idx in 0..cfg.encoder_layers {
            let layer = EncoderLayer::new(cfg, vb_l.pp(idx))?;
            layers.push(layer)
        }
        let embed_scale = if cfg.scale_embedding {
            Some((cfg.d_model as f64).sqrt())
        } else {
            None
        };
        Ok(Self {
            embed_tokens: embed_tokens.clone(),
            embed_positions,
            layers,
            embed_scale,
        })
    }

    pub fn forward(&mut self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
        let xs = xs.apply(&self.embed_tokens)?;
        let xs = match self.embed_scale {
            None => xs,
            Some(scale) => (xs * scale)?,
        };
        let embed_pos = self
            .embed_positions
            .forward(&xs, past_kv_len)?
            .unsqueeze(0)?;
        let mut xs = xs.broadcast_add(&embed_pos)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs)?
        }
        Ok(xs)
    }

    pub fn reset_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.reset_kv_cache()
        }
    }
}

#[derive(Debug, Clone)]
pub struct Decoder {
    embed_tokens: Embedding,
    embed_positions: SinusoidalPositionalEmbedding,
    layers: Vec<DecoderLayer>,
    embed_scale: Option<f64>,
}

impl Decoder {
    fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
        let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
        let mut layers = Vec::with_capacity(cfg.decoder_layers);
        let vb_l = vb.pp("layers");
        for idx in 0..cfg.decoder_layers {
            let layer = DecoderLayer::new(cfg, vb_l.pp(idx))?;
            layers.push(layer)
        }
        let embed_scale = if cfg.scale_embedding {
            Some((cfg.d_model as f64).sqrt())
        } else {
            None
        };
        Ok(Self {
            embed_tokens: embed_tokens.clone(),
            embed_positions,
            layers,
            embed_scale,
        })
    }

    pub fn forward(
        &mut self,
        xs: &Tensor,
        encoder_xs: Option<&Tensor>,
        past_kv_len: usize,
        attn_mask: &Tensor,
    ) -> Result<Tensor> {
        let xs = xs.apply(&self.embed_tokens)?;
        let xs = match self.embed_scale {
            None => xs,
            Some(scale) => (xs * scale)?,
        };
        let embed_pos = self
            .embed_positions
            .forward(&xs, past_kv_len)?
            .unsqueeze(0)?;
        let mut xs = xs.broadcast_add(&embed_pos)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, encoder_xs, attn_mask)?;
        }
        Ok(xs)
    }

    pub fn reset_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.reset_kv_cache()
        }
    }
}

#[derive(Debug, Clone)]
struct Model {
    shared: Embedding,
    encoder: Encoder,
    decoder: Decoder,
}

impl Model {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
        let encoder = Encoder::new(cfg, &shared, vb.pp("encoder"))?;
        let decoder = Decoder::new(cfg, &shared, vb.pp("decoder"))?;
        Ok(Self {
            shared,
            encoder,
            decoder,
        })
    }

    fn reset_kv_cache(&mut self) {
        self.encoder.reset_kv_cache();
        self.decoder.reset_kv_cache();
    }
}

#[derive(Debug, Clone)]
pub struct MTModel {
    model: Model,
    lm_head: Linear,
    final_logits_bias: Tensor,
}

impl MTModel {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
        let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
        let model = Model::new(cfg, vb.pp("model"))?;
        let lm_head = Linear::from_weights(model.shared.embeddings().clone(), None);
        Ok(Self {
            model,
            lm_head,
            final_logits_bias,
        })
    }

    pub fn encoder(&mut self) -> &mut Encoder {
        &mut self.model.encoder
    }

    pub fn decoder(&mut self) -> &mut Decoder {
        &mut self.model.decoder
    }

    pub fn decode(
        &mut self,
        xs: &Tensor,
        encoder_xs: &Tensor,
        past_kv_len: usize,
    ) -> Result<Tensor> {
        let seq_len = xs.dim(1)?;
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
            .collect();
        let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
        self.model
            .decoder
            .forward(xs, Some(encoder_xs), past_kv_len, &mask)?
            .apply(&self.lm_head)?
            .broadcast_add(&self.final_logits_bias)
    }

    pub fn reset_kv_cache(&mut self) {
        self.model.reset_kv_cache();
    }
}
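`SinusoidalPositionalEmbedding` above precomputes a (max_position_embeddings × d_model) table whose rows concatenate a sin half and a cos half (note the `Tensor::cat(&[&sin, &cos], 1)`: halves, not interleaved pairs). A plain-Rust sketch of the same table:

```rust
// Minimal sketch of the sinusoidal table built above: for even index i,
// freq_i = 1 / 10000^(i / dim); row t = [sin(t*f)..., cos(t*f)...].
fn sinusoidal_table(num_positions: usize, dim: usize) -> Vec<Vec<f32>> {
    let inv_freq: Vec<f32> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
        .collect();
    (0..num_positions)
        .map(|t| {
            let sin: Vec<f32> = inv_freq.iter().map(|f| (t as f32 * f).sin()).collect();
            let cos = inv_freq.iter().map(|f| (t as f32 * f).cos());
            sin.into_iter().chain(cos).collect()
        })
        .collect()
}

fn main() {
    let table = sinusoidal_table(4, 8);
    // Position 0 is all zeros in the sin half and all ones in the cos half.
    assert!(table[0][..4].iter().all(|&x| x == 0.0));
    assert!(table[0][4..].iter().all(|&x| x == 1.0));
}
```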
@@ -73,6 +73,23 @@ impl Config {
            pad_vocab_size_multiple: 64,
        }
    }

    // https://huggingface.co/teknium/Phi-Hermes-1.3B/blob/main/config.json
    pub fn phi_hermes_1_3b() -> Self {
        Self {
            vocab_size: 50304,
            n_positions: 2048,
            n_embd: 2048,
            n_layer: 24,
            n_inner: None,
            n_head: 32,
            rotary_dim: usize::min(32, 2048 / 32),
            activation_function: Activation::NewGelu,
            layer_norm_epsilon: 1e-5,
            tie_word_embeddings: false,
            pad_vocab_size_multiple: 64,
        }
    }
}

#[derive(Debug, Clone)]
@@ -6,13 +6,19 @@ pub mod convmixer;
pub mod dinov2;
pub mod efficientnet;
pub mod falcon;
pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
pub mod marian;
pub mod mistral;
pub mod mixformer;
pub mod mpt;
pub mod persimmon;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
pub mod quantized_llama2_c;
pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_mpt;
@@ -23,6 +29,7 @@ pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod t5;
pub mod vgg;
pub mod vit;
pub mod whisper;
pub mod with_tracing;
candle-transformers/src/models/persimmon.rs (new file, 56 lines)
@@ -0,0 +1,56 @@
use candle::DType;
use serde::Deserialize;

pub const DTYPE: DType = DType::F32;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
    Absolute,
    Alibi,
}

// https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub hidden_act: candle_nn::Activation,
    pub max_position_embeddings: usize,
    pub initializer_range: f64,
    pub layer_norm_eps: f64,
    pub rms_norm_eps: f64,
    pub use_cache: bool,
    pub tie_word_embeddings: bool,
    pub rope_theta: f64,
    pub qk_layernorm: bool,
    pub partial_rotary_factor: f64,
}

impl Config {
    pub fn base_8b() -> Self {
        // https://huggingface.co/adept/persimmon-8b-base/blob/main/config.json
        Self {
            hidden_act: candle_nn::Activation::Relu,
            hidden_size: 4096,
            initializer_range: 0.02,
            intermediate_size: 16384,
            layer_norm_eps: 1e-05,
            max_position_embeddings: 16384,
            num_attention_heads: 64,
            num_hidden_layers: 36,
            num_key_value_heads: 64,
            qk_layernorm: true,
            rms_norm_eps: 1e-06,
            rope_theta: 25000.0,
            tie_word_embeddings: false,
            use_cache: true,
            vocab_size: 262144,
            partial_rotary_factor: 0.5,
        }
    }
}
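The persimmon config carries a `partial_rotary_factor` of 0.5: only a prefix of each head's dimensions gets rotary position embeddings. A small sketch of the implied split (this is an assumption about how the factor is consumed, matching the Hugging Face reference implementation linked above):

```rust
// Sketch of the rotary split implied by partial_rotary_factor above.
fn rotary_dims(hidden_size: usize, num_heads: usize, partial_rotary_factor: f64) -> usize {
    let head_dim = hidden_size / num_heads;
    (head_dim as f64 * partial_rotary_factor) as usize
}

fn main() {
    // base_8b: 4096 hidden, 64 heads -> head_dim 64, rotary on the first 32 dims.
    assert_eq!(rotary_dims(4096, 64, 0.5), 32);
}
```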
@@ -255,4 +255,7 @@ impl BlipForConditionalGeneration {
    pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
        &mut self.text_decoder
    }
    pub fn reset_kv_cache(&mut self) {
        self.text_decoder.reset_kv_cache();
    }
}
candle-transformers/src/models/quantized_llama2_c.rs (new file, 227 lines)
@@ -0,0 +1,227 @@
use super::llama2_c::{Cache, Config};
use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, IndexOp, Module, Result, Tensor, D};

fn silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

struct CausalSelfAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    n_head: usize,
    n_key_value_head: usize,
    head_dim: usize,
    cache: Cache,
}

impl CausalSelfAttention {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (b_sz, seq_len, h, n_embd) = x.dims4()?;
        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
        let cos = cos.unsqueeze(1)?;
        let sin = sin.unsqueeze(1)?;
        let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
        let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
        let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
        let x0 = x.narrow(D::Minus1, 0, 1)?;
        let x1 = x.narrow(D::Minus1, 1, 1)?;
        let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
        let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
        let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
        Ok(rope)
    }

    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
        let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
        let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;

        let q = self.apply_rotary_emb(&q, index_pos)?;
        let mut k = self.apply_rotary_emb(&k, index_pos)?;

        if self.cache.use_kv_cache {
            let mut cache = self.cache.kvs.lock().unwrap();
            if let Some((cache_k, cache_v)) = &cache[block_idx] {
                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
            }
            cache[block_idx] = Some((k.clone(), v.clone()))
        }

        let k = self.repeat_kv(k)?;
        let v = self.repeat_kv(v)?;

        let q = q.transpose(1, 2)?.contiguous()?;
        let k = k.transpose(1, 2)?.contiguous()?;
        let v = v.transpose(1, 2)?.contiguous()?;

        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
        // Convert to contiguous as matmul doesn't support strided vs for now.
        let y = att.matmul(&v.contiguous()?)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
        let y = self.o_proj.forward(&y)?;
        Ok(y)
    }

    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.n_head / self.n_key_value_head;
        if n_rep == 1 {
            Ok(x)
        } else {
            let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
            let x = x
                .unsqueeze(3)?
                .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
                .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
            Ok(x)
        }
    }

    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let size_in = cfg.dim;
        let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
        let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            n_head: cfg.n_heads,
            n_key_value_head: cfg.n_kv_heads,
            head_dim: cfg.dim / cfg.n_heads,
            cache: cache.clone(),
        })
    }
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

struct Mlp {
    c_fc1: Linear,
    c_fc2: Linear,
    c_proj: Linear,
}

impl Mlp {
    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
        Self {
            c_fc1,
            c_fc2,
            c_proj,
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&x)
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let h_size = cfg.dim;
        let i_size = cfg.hidden_dim;
        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
        Ok(Self::new(c_fc1, c_fc2, c_proj))
    }
}

struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
        Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        }
    }

    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
        let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm =
            RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
        Ok(Self::new(
            input_layernorm,
            attn,
            post_attention_layernorm,
            mlp,
        ))
    }
}

pub struct QLlama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Linear,
    pub config: Config,
}

impl QLlama {
    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, _seq_len) = x.dims2()?;
        let mut x = self.wte.forward(x)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx)?;
        }
        let x = self.ln_f.forward(&x)?;
        let logits = self.lm_head.forward(&x)?;
        logits.to_dtype(DType::F32)
    }

    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
        let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
        let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.n_layers)
            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
            .collect();
        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            config: cfg,
        })
    }
}
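Two helpers in this file are easy to sanity-check on plain floats: `silu` is defined by hand as x / (1 + e^-x), and `repeat_kv` widens grouped KV heads by duplicating each of the `n_kv_heads` groups `n_head / n_kv_heads` times. A scalar sketch of both:

```rust
// Scalar sketch of the helpers above: silu(x) = x * sigmoid(x) = x / (1 + e^-x),
// and repeat_kv duplicating each kv head contiguously, as the expand+reshape does.
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

fn repeat_kv(heads: &[u32], n_head: usize) -> Vec<u32> {
    let n_rep = n_head / heads.len();
    heads
        .iter()
        .flat_map(|&h| std::iter::repeat(h).take(n_rep))
        .collect()
}

fn main() {
    assert!(silu(0.0).abs() < 1e-6); // silu(0) = 0
    assert_eq!(repeat_kv(&[0, 1], 4), vec![0, 0, 1, 1]);
}
```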
@@ -7,7 +7,7 @@ use std::sync::Arc;
pub use crate::models::stable_lm::Config;
use crate::models::stable_lm::RotaryEmbedding;

#[derive(Debug)]
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Linear,
@@ -43,7 +43,7 @@ impl Module for MLP {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
@@ -168,7 +168,7 @@ impl Attention {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
@@ -213,7 +213,7 @@ impl DecoderLayer {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
@@ -93,7 +93,7 @@ impl Default for Config {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerNorm {
    weight: Tensor,
    variance_epsilon: f64,
@@ -125,7 +125,7 @@ impl Module for T5LayerNorm {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5DenseActDense {
    wi: QMatMul,
    wo: QMatMul,
@@ -156,7 +156,7 @@ impl Module for T5DenseActDense {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
    wi_0: QMatMul,
    wi_1: QMatMul,
@@ -191,7 +191,7 @@ impl Module for T5DenseGatedActDense {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerFF {
    dense_act: Option<T5DenseActDense>,
    gated_dense_act: Option<T5DenseGatedActDense>,
@@ -236,7 +236,7 @@ impl Module for T5LayerFF {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Attention {
    q: QMatMul,
    k: QMatMul,
@@ -431,7 +431,7 @@ impl T5Attention {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
    self_attention: T5Attention,
    layer_norm: T5LayerNorm,
@@ -470,7 +470,7 @@ impl T5LayerSelfAttention {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
    cross_attention: T5Attention,
    layer_norm: T5LayerNorm,
@@ -512,7 +512,7 @@ impl T5LayerCrossAttention {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Block {
    self_attn: T5LayerSelfAttention,
    cross_attn: Option<T5LayerCrossAttention>,
@@ -583,7 +583,7 @@ impl T5Block {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Stack {
    block: Vec<T5Block>,
    shared: Arc<Embedding>,
@@ -633,7 +633,7 @@ impl T5Stack {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct T5EncoderModel {
    encoder: T5Stack,
    device: Device,
@@ -666,7 +666,7 @@ impl T5EncoderModel {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration {
    encoder: T5Stack,
    decoder: T5Stack,
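The blanket change from `#[derive(Debug)]` to `#[derive(Debug, Clone)]` across the T5 stack is cheap because the heavy state (for example the `shared: Arc<Embedding>` field visible above) is reference-counted: cloning copies pointers, not tensors. A sketch of why such a clone stays O(1), with the `Arc<Vec<f32>>` standing in for the shared weights:

```rust
use std::sync::Arc;

// Sketch: cloning a struct whose big buffer sits behind an Arc only bumps a
// refcount; both clones observe the same underlying weights.
#[derive(Debug, Clone)]
struct Stack {
    shared: Arc<Vec<f32>>,
}

fn main() {
    let a = Stack { shared: Arc::new(vec![0.0; 1_000_000]) };
    let b = a.clone();
    assert!(Arc::ptr_eq(&a.shared, &b.shared)); // no weight copy happened
}
```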
@@ -1,3 +1,4 @@
pub use crate::models::with_tracing::Linear;
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};

@@ -9,13 +10,11 @@ pub mod tiny_vit;
pub mod transformer;

pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
    let inner = if bias {
        candle_nn::linear(in_dim, out_dim, vb)?
    if bias {
        crate::models::with_tracing::linear(in_dim, out_dim, vb)
    } else {
        candle_nn::linear_no_bias(in_dim, out_dim, vb)?
    };
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    Ok(Linear { inner, span })
        crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
    }
}

#[derive(Debug)]
@@ -85,16 +84,3 @@ impl Module for MlpBlock {
        .apply(&self.lin2)
    }
}

#[derive(Debug)]
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Module for Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}
@@ -102,6 +102,14 @@ impl Config {
        }
    }

    pub fn ssd1b() -> Self {
        Self::sdxl()
    }

    pub fn ssd1b2() -> Self {
        Self::sdxl2()
    }

    // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
    pub fn wuerstchen() -> Self {
        Self {
@@ -249,6 +249,71 @@ impl StableDiffusionConfig {
        )
    }

    pub fn ssd1b(
        sliced_attention_size: Option<usize>,
        height: Option<usize>,
        width: Option<usize>,
    ) -> Self {
        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
            out_channels,
            use_cross_attn,
            attention_head_dim,
        };
        // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
        let unet = unet_2d::UNet2DConditionModelConfig {
            blocks: vec![
                bc(320, None, 5),
                bc(640, Some(2), 10),
                bc(1280, Some(10), 20),
            ],
            center_input_sample: false,
            cross_attention_dim: 2048,
            downsample_padding: 1,
            flip_sin_to_cos: true,
            freq_shift: 0.,
            layers_per_block: 2,
            mid_block_scale_factor: 1.,
            norm_eps: 1e-5,
            norm_num_groups: 32,
            sliced_attention_size,
            use_linear_projection: true,
        };
        // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
        let autoencoder = vae::AutoEncoderKLConfig {
            block_out_channels: vec![128, 256, 512, 512],
            layers_per_block: 2,
            latent_channels: 4,
            norm_num_groups: 32,
        };
        let scheduler = ddim::DDIMSchedulerConfig {
            ..Default::default()
        };

        let height = if let Some(height) = height {
            assert_eq!(height % 8, 0, "height has to be divisible by 8");
            height
        } else {
            1024
        };

        let width = if let Some(width) = width {
            assert_eq!(width % 8, 0, "width has to be divisible by 8");
            width
        } else {
            1024
        };

        Self {
            width,
            height,
            clip: clip::Config::ssd1b(),
            clip2: Some(clip::Config::ssd1b2()),
            autoencoder,
            scheduler,
            unet,
        }
    }

    pub fn build_vae<P: AsRef<std::path::Path>>(
        &self,
        vae_weights: P,
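The `ssd1b` builder asserts that height and width are divisible by 8 because the VAE above downsamples by a factor of 8 (the four-block encoder with stride-2 stages between blocks), so the latent grid is (h/8, w/8). As arithmetic, under that assumption:

```rust
// Sketch of the VAE spatial arithmetic behind the divisible-by-8 asserts above.
fn latent_shape(h: usize, w: usize, latent_channels: usize) -> (usize, usize, usize) {
    assert_eq!(h % 8, 0, "height has to be divisible by 8");
    assert_eq!(w % 8, 0, "width has to be divisible by 8");
    (latent_channels, h / 8, w / 8)
}

fn main() {
    // The 1024x1024 default above yields a 4x128x128 latent.
    assert_eq!(latent_shape(1024, 1024, 4), (4, 128, 128));
}
```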
@@ -118,7 +118,7 @@ impl Config {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerNorm {
    weight: Tensor,
    variance_epsilon: f64,
@@ -150,7 +150,7 @@ impl Module for T5LayerNorm {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5DenseActDense {
    wi: Linear,
    wo: Linear,
@@ -181,7 +181,7 @@ impl Module for T5DenseActDense {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
    wi_0: Linear,
    wi_1: Linear,
@@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerFF {
    dense_act: Option<T5DenseActDense>,
    gated_dense_act: Option<T5DenseGatedActDense>,
@@ -261,7 +261,7 @@ impl Module for T5LayerFF {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Attention {
    q: Linear,
    k: Linear,
@@ -456,7 +456,7 @@ impl T5Attention {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
    self_attention: T5Attention,
    layer_norm: T5LayerNorm,
@@ -495,7 +495,7 @@ impl T5LayerSelfAttention {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
    cross_attention: T5Attention,
    layer_norm: T5LayerNorm,
@@ -537,7 +537,7 @@ impl T5LayerCrossAttention {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Block {
    self_attn: T5LayerSelfAttention,
    cross_attn: Option<T5LayerCrossAttention>,
@@ -608,7 +608,7 @@ impl T5Block {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Stack {
    block: Vec<T5Block>,
    shared: Arc<Embedding>,
@@ -658,7 +658,7 @@ impl T5Stack {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct T5EncoderModel {
    encoder: T5Stack,
    device: Device,
@@ -691,7 +691,7 @@ impl T5EncoderModel {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration {
    encoder: T5Stack,
    decoder: T5Stack,
257
candle-transformers/src/models/vgg.rs
Normal file
@ -0,0 +1,257 @@
//! VGG model implementation (VGG-13, VGG-16 and VGG-19).
//!
//! See Very Deep Convolutional Networks for Large-Scale Image Recognition
//! <https://arxiv.org/abs/1409.1556>
use candle::{ModuleT, Result, Tensor};
use candle_nn::{FuncT, VarBuilder};

// Enum representing the different VGG models
pub enum Models {
    Vgg13,
    Vgg16,
    Vgg19,
}

// Struct representing a VGG model
#[derive(Debug)]
pub struct Vgg<'a> {
    blocks: Vec<FuncT<'a>>,
}

// Struct representing the configuration for the pre-logit layer
struct PreLogitConfig {
    in_dim: (usize, usize, usize, usize),
    target_in: usize,
    target_out: usize,
}

// Implementation of the VGG model
impl<'a> Vgg<'a> {
    // Function to create a new VGG model
    pub fn new(vb: VarBuilder<'a>, model: Models) -> Result<Self> {
        let blocks = match model {
            Models::Vgg13 => vgg13_blocks(vb)?,
            Models::Vgg16 => vgg16_blocks(vb)?,
            Models::Vgg19 => vgg19_blocks(vb)?,
        };
        Ok(Self { blocks })
    }
}

// Implementation of the forward pass for the VGG model
impl ModuleT for Vgg<'_> {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        let mut xs = xs.unsqueeze(0)?;
        for block in self.blocks.iter() {
            xs = xs.apply_t(block, train)?;
        }
        Ok(xs)
    }
}

// Function to create a conv2d block
// The block is composed of a sequence of conv2d layers (two to four, depending on
// the variant), each followed by a ReLU, and a final max pool layer
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
    let layers = convs
        .iter()
        .enumerate()
        .map(|(_, &(in_c, out_c, name))| {
            candle_nn::conv2d(
                in_c,
                out_c,
                3,
                candle_nn::Conv2dConfig {
                    stride: 1,
                    padding: 1,
                    ..Default::default()
                },
                vb.pp(name),
            )
        })
        .collect::<Result<Vec<_>>>()?;

    Ok(FuncT::new(move |xs, _train| {
        let mut xs = xs.clone();
        for layer in layers.iter() {
            xs = xs.apply(layer)?.relu()?
        }
        xs = xs.max_pool2d_with_stride(2, 2)?;
        Ok(xs)
    }))
}

// Function to create the fully connected head
// The head is composed of two dropout/linear/ReLU pre-logit layers followed by a
// final dropout/linear/ReLU layer mapping to the number of classes
fn fully_connected(
    num_classes: usize,
    pre_logit_1: PreLogitConfig,
    pre_logit_2: PreLogitConfig,
    vb: VarBuilder,
) -> Result<FuncT> {
    let lin = get_weights_and_biases(
        &vb.pp("pre_logits.fc1"),
        pre_logit_1.in_dim,
        pre_logit_1.target_in,
        pre_logit_1.target_out,
    )?;
    let lin2 = get_weights_and_biases(
        &vb.pp("pre_logits.fc2"),
        pre_logit_2.in_dim,
        pre_logit_2.target_in,
        pre_logit_2.target_out,
    )?;
    let dropout1 = candle_nn::Dropout::new(0.5);
    let dropout2 = candle_nn::Dropout::new(0.5);
    let dropout3 = candle_nn::Dropout::new(0.5);
    Ok(FuncT::new(move |xs, train| {
        let xs = xs.reshape((1, pre_logit_1.target_out))?;
        let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?;
        let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?;
        let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
        let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?;
        Ok(xs)
    }))
}

// Function to get the weights and biases for a layer
// This is required because the weights and biases are stored in a different format
// than our linear layer expects
fn get_weights_and_biases(
    vs: &VarBuilder,
    in_dim: (usize, usize, usize, usize),
    target_in: usize,
    target_out: usize,
) -> Result<candle_nn::Linear> {
    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
    let ws = vs.get_with_hints(in_dim, "weight", init_ws)?;
    let ws = ws.reshape((target_in, target_out))?;
    let bound = 1. / (target_out as f64).sqrt();
    let init_bs = candle_nn::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let bs = vs.get_with_hints(target_in, "bias", init_bs)?;
    Ok(candle_nn::Linear::new(ws, Some(bs)))
}

fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
    let num_classes = 1000;
    let blocks = vec![
        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
        conv2d_block(&[(128, 256, "features.10"), (256, 256, "features.12")], &vb)?,
        conv2d_block(&[(256, 512, "features.15"), (512, 512, "features.17")], &vb)?,
        conv2d_block(&[(512, 512, "features.20"), (512, 512, "features.22")], &vb)?,
        fully_connected(
            num_classes,
            PreLogitConfig {
                in_dim: (4096, 512, 7, 7),
                target_in: 4096,
                target_out: 512 * 7 * 7,
            },
            PreLogitConfig {
                in_dim: (4096, 4096, 1, 1),
                target_in: 4096,
                target_out: 4096,
            },
            vb.clone(),
        )?,
    ];
    Ok(blocks)
}

fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
    let num_classes = 1000;
    let blocks = vec![
        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
        conv2d_block(
            &[
                (128, 256, "features.10"),
                (256, 256, "features.12"),
                (256, 256, "features.14"),
            ],
            &vb,
        )?,
        conv2d_block(
            &[
                (256, 512, "features.17"),
                (512, 512, "features.19"),
                (512, 512, "features.21"),
            ],
            &vb,
        )?,
        conv2d_block(
            &[
                (512, 512, "features.24"),
                (512, 512, "features.26"),
                (512, 512, "features.28"),
            ],
            &vb,
        )?,
        fully_connected(
            num_classes,
            PreLogitConfig {
                in_dim: (4096, 512, 7, 7),
                target_in: 4096,
                target_out: 512 * 7 * 7,
            },
            PreLogitConfig {
                in_dim: (4096, 4096, 1, 1),
                target_in: 4096,
                target_out: 4096,
            },
            vb.clone(),
        )?,
    ];
    Ok(blocks)
}

fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
    let num_classes = 1000;
    let blocks = vec![
        conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
        conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
        conv2d_block(
            &[
                (128, 256, "features.10"),
                (256, 256, "features.12"),
                (256, 256, "features.14"),
                (256, 256, "features.16"),
            ],
            &vb,
        )?,
        conv2d_block(
            &[
                (256, 512, "features.19"),
                (512, 512, "features.21"),
                (512, 512, "features.23"),
                (512, 512, "features.25"),
            ],
            &vb,
        )?,
        conv2d_block(
            &[
                (512, 512, "features.28"),
                (512, 512, "features.30"),
                (512, 512, "features.32"),
                (512, 512, "features.34"),
            ],
            &vb,
        )?,
        fully_connected(
            num_classes,
            PreLogitConfig {
                in_dim: (4096, 512, 7, 7),
                target_in: 4096,
                target_out: 512 * 7 * 7,
            },
            PreLogitConfig {
                in_dim: (4096, 4096, 1, 1),
                target_in: 4096,
                target_out: 4096,
            },
            vb.clone(),
        )?,
    ];
    Ok(blocks)
}
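A minimal usage sketch for the new module (the weights file name is a placeholder; tensor names are expected in the timm-style layout used above, i.e. `features.*`, `pre_logits.*` and `head.fc`; `forward_t` adds the batch dimension itself, so a single `(3, 224, 224)` image is passed in):

```rust
use candle::{DType, Device, ModuleT, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::vgg::{Models, Vgg};

fn run_vgg() -> candle::Result<()> {
    let device = Device::Cpu;
    // Placeholder weights file; mmap is unsafe because the file must not change underneath us.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["vgg19.safetensors"], DType::F32, &device)?
    };
    let model = Vgg::new(vb, Models::Vgg19)?;
    let image = Tensor::zeros((3, 224, 224), DType::F32, &device)?;
    let logits = model.forward_t(&image, false)?; // shape (1, 1000)
    println!("{:?}", logits.shape());
    Ok(())
}
```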
@ -1,4 +1,5 @@
use super::Config;
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};

@ -6,33 +7,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
    Ok(Embedding::new(embeddings, hidden_size))
}
//
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear(size1, size2, vb)?;
    Ok(Linear { inner, span })
}

fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
    Ok(Linear { inner, span })
}

fn conv1d(
    in_channels: usize,
@ -53,6 +27,7 @@ fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
#[derive(Debug, Clone)]
struct MultiHeadAttention {
    query: Linear,
    key: Linear,
@ -162,6 +137,7 @@ impl MultiHeadAttention {
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
#[derive(Debug, Clone)]
struct ResidualAttentionBlock {
    attn: MultiHeadAttention,
    attn_ln: LayerNorm,
@ -241,6 +217,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
#[derive(Debug, Clone)]
pub struct AudioEncoder {
    conv1: Conv1d,
    conv2: Conv1d,
@ -316,6 +293,7 @@ impl AudioEncoder {
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
#[derive(Debug, Clone)]
pub struct TextDecoder {
    token_embedding: Embedding,
    positional_embedding: Tensor,
@ -380,6 +358,7 @@ impl TextDecoder {
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
#[derive(Debug, Clone)]
pub struct Whisper {
    pub encoder: AudioEncoder,
    pub decoder: TextDecoder,
@ -19,6 +19,7 @@ fn conv1d(
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
#[derive(Debug, Clone)]
struct MultiHeadAttention {
    query: Linear,
    key: Linear,
@ -128,6 +129,7 @@ impl MultiHeadAttention {
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
#[derive(Debug, Clone)]
struct ResidualAttentionBlock {
    attn: MultiHeadAttention,
    attn_ln: LayerNorm,
@ -206,6 +208,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
#[derive(Debug, Clone)]
pub struct AudioEncoder {
    conv1: Conv1d,
    conv2: Conv1d,
@ -281,6 +284,7 @@ impl AudioEncoder {
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
#[derive(Debug, Clone)]
pub struct TextDecoder {
    token_embedding: Embedding,
    positional_embedding: Tensor,
@ -347,6 +351,7 @@ impl TextDecoder {
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
#[derive(Debug, Clone)]
pub struct Whisper {
    pub encoder: AudioEncoder,
    pub decoder: TextDecoder,
@ -14,6 +14,13 @@ impl Embedding {
        Ok(Self { inner, span })
    }

    pub fn from_weights(weights: Tensor) -> Result<Self> {
        let (_in_size, out_size) = weights.dims2()?;
        let inner = candle_nn::Embedding::new(weights, out_size);
        let span = tracing::span!(tracing::Level::TRACE, "embedding");
        Ok(Self { inner, span })
    }

    pub fn embeddings(&self) -> &Tensor {
        self.inner.embeddings()
    }
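The new `from_weights` constructor is convenient when the embedding matrix already exists as a tensor, for instance for weight tying. A small sketch (the shapes are illustrative, and it assumes the `with_tracing` module is reachable from the caller):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_transformers::models::with_tracing::Embedding;

fn tied_embedding() -> Result<Embedding> {
    // A (vocab_size, hidden_size) matrix loaded or computed elsewhere.
    let weights = Tensor::zeros((32000, 768), DType::F32, &Device::Cpu)?;
    Embedding::from_weights(weights)
}
```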
@ -77,6 +77,16 @@ impl VarBuilder {
        }
    }

    pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
        let path = self.path(name);
        match self.data.get(&path) {
            None => {
                candle::bail!("cannot find tensor {name}")
            }
            Some(qtensor) => Ok(qtensor.clone()),
        }
    }

    pub fn device(&self) -> &Device {
        &self.device
    }
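Unlike the shape-checked `get`, this accessor returns the raw quantized tensor, which helps when a tensor's dimensions are only known from the gguf file itself. A sketch under the assumption that this is the quantized `VarBuilder` from `candle_transformers::quantized_var_builder` (the tensor name is a placeholder):

```rust
use candle_transformers::quantized_var_builder::VarBuilder;

fn inspect_tensor(vb: &VarBuilder) -> candle::Result<()> {
    // No expected shape is passed; whatever is stored in the gguf is returned.
    let qtensor = vb.get_no_shape("lm_head.weight")?;
    println!("lm_head.weight: {:?}", qtensor.shape());
    Ok(())
}
```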
31
candle-wasm-examples/blip/Cargo.toml
Normal file
@ -0,0 +1,31 @@
[package]
name = "candle-wasm-example-blip"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

# App crates.
anyhow = { workspace = true }
byteorder = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
image = { workspace = true }
log = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

# Wasm specific crates.
console_error_panic_hook = "0.1.7"
wasm-bindgen = "0.2.87"
js-sys = "0.3.64"
23
candle-wasm-examples/blip/README.md
Normal file
@ -0,0 +1,23 @@
## Running [BLIP Image Captioning](https://huggingface.co/Salesforce/blip-image-captioning-large) Example

### Vanilla JS and WebWorkers

To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:

```bash
sh build-lib.sh
```

This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:

```js
import init, { Model } from "./build/m.js";
```

The full example can be found under `./index.html`. All needed assets are fetched from the web, so no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:

```bash
python -m http.server
```

Then open `http://localhost:8000/index.html` in your browser.
77
candle-wasm-examples/blip/blipWorker.js
Normal file
@ -0,0 +1,77 @@
import init, { Model } from "./build/m.js";

async function fetchArrayBuffer(url, cacheFile = true) {
  if (!cacheFile) return new Uint8Array(await (await fetch(url)).arrayBuffer());
  const cacheName = "blip-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    const data = await cachedResponse.arrayBuffer();
    return new Uint8Array(data);
  }
  const res = await fetch(url, { cache: "force-cache" });
  cache.put(url, res.clone());
  return new Uint8Array(await res.arrayBuffer());
}
class Blip {
  static instance = {};

  static async getInstance(
    weightsURL,
    tokenizerURL,
    configURL,
    modelID,
    quantized
  ) {
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({ status: "loading", message: "Loading Model" });
      const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
        await Promise.all([
          fetchArrayBuffer(weightsURL),
          fetchArrayBuffer(tokenizerURL),
          fetchArrayBuffer(configURL),
        ]);

      this.instance[modelID] = new Model(
        weightsArrayU8,
        tokenizerArrayU8,
        configArrayU8,
        quantized
      );
    } else {
      self.postMessage({ status: "ready", message: "Model Already Loaded" });
    }
    return this.instance[modelID];
  }
}

self.addEventListener("message", async (event) => {
  const { weightsURL, tokenizerURL, configURL, modelID, imageURL, quantized } =
    event.data;
  try {
    self.postMessage({ status: "status", message: "Loading Blip Model..." });
    const model = await Blip.getInstance(
      weightsURL,
      tokenizerURL,
      configURL,
      modelID,
      quantized
    );
    self.postMessage({
      status: "status",
      message: "Running Blip Inference...",
    });
    const imageArrayU8 = await fetchArrayBuffer(imageURL, false);
    const output = model.generate_caption_from_image(imageArrayU8);

    self.postMessage({
      status: "complete",
      message: "complete",
      output: output,
    });
  } catch (e) {
    self.postMessage({ error: e });
  }
});
2
candle-wasm-examples/blip/build-lib.sh
Executable file
@ -0,0 +1,2 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web
393
candle-wasm-examples/blip/index.html
Normal file
@ -0,0 +1,393 @@
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <style>
      @import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
      html,
      body {
        font-family: "Source Sans 3", sans-serif;
      }
    </style>
    <title>Candle Blip Image Captioning Demo</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <script type="module" src="./code.js"></script>
    <script type="module">
      const MODELS = {
        blip_image_quantized_q4k: {
          base_url: "https://huggingface.co/lmz/candle-blip/resolve/main/",
          model: "blip-image-captioning-large-q4k.gguf",
          config: "config.json",
          tokenizer: "tokenizer.json",
          quantized: true,
          size: "271 MB",
        },
        blip_image_quantized_q80: {
          base_url: "https://huggingface.co/lmz/candle-blip/resolve/main/",
          model: "blip-image-captioning-large-q80.gguf",
          config: "config.json",
          tokenizer: "tokenizer.json",
          quantized: true,
          size: "505 MB",
        },
        blip_image_large: {
          base_url:
            "https://huggingface.co/Salesforce/blip-image-captioning-large/resolve/refs%2Fpr%2F18/",
          model: "model.safetensors",
          config: "config.json",
          tokenizer: "tokenizer.json",
          quantized: false,
          size: "1.88 GB",
        },
      };

      const blipWorker = new Worker("./blipWorker.js", {
        type: "module",
      });

      const outputStatusEl = document.querySelector("#output-status");
      const outputCaptionEl = document.querySelector("#output-caption");
      const modelSelectEl = document.querySelector("#model");
      const clearBtn = document.querySelector("#clear-btn");
      const fileUpload = document.querySelector("#file-upload");
      const dropArea = document.querySelector("#drop-area");
      const imagesExamples = document.querySelector("#image-select");
      const canvas = document.querySelector("#canvas");
      const ctxCanvas = canvas.getContext("2d");

      let isCaptioning = false;
      let currentImageURL = null;
      clearBtn.addEventListener("click", () => {
        clearImageCanvas();
      });
      modelSelectEl.addEventListener("change", () => {
        if (currentImageURL) {
          runInference(currentImageURL);
        }
      });

      // add event listener to file input
      fileUpload.addEventListener("input", async (e) => {
        const target = e.target;
        if (target.files.length > 0) {
          const href = URL.createObjectURL(target.files[0]);
          clearImageCanvas();
          await drawImageCanvas(href);
          runInference(href);
        }
      });
      // add event listener to drop-area
      dropArea.addEventListener("dragenter", (e) => {
        e.preventDefault();
        dropArea.classList.add("border-blue-700");
      });
      dropArea.addEventListener("dragleave", (e) => {
        e.preventDefault();
        dropArea.classList.remove("border-blue-700");
      });
      dropArea.addEventListener("dragover", (e) => {
        e.preventDefault();
      });
      dropArea.addEventListener("drop", async (e) => {
        e.preventDefault();
        dropArea.classList.remove("border-blue-700");
        const url = e.dataTransfer.getData("text/uri-list");
        const files = e.dataTransfer.files;

        if (files.length > 0) {
          const href = URL.createObjectURL(files[0]);
          clearImageCanvas();
          await drawImageCanvas(href);
          runInference(href);
        } else if (url) {
          clearImageCanvas();
          await drawImageCanvas(url);
          runInference(url);
        }
      });

      imagesExamples.addEventListener("click", async (e) => {
        if (isCaptioning) {
          return;
        }
        const target = e.target;
        if (target.nodeName === "IMG") {
          const href = target.src;
          clearImageCanvas();
          await drawImageCanvas(href);
          runInference(href);
        }
      });
      function clearImageCanvas() {
        ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
        isCaptioning = false;
        clearBtn.disabled = true;
        canvas.parentElement.style.height = "auto";
        outputStatusEl.hidden = false;
        outputCaptionEl.hidden = true;
        outputStatusEl.innerText = "Please select an image";
        currentImageURL = null;
      }

      async function drawImageCanvas(imgURL) {
        if (!imgURL) {
          throw new Error("No image URL provided");
        }
        return new Promise((resolve, reject) => {
          ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);

          const img = new Image();
          img.crossOrigin = "anonymous";
          img.onload = () => {
            canvas.width = img.width;
            canvas.height = img.height;
            ctxCanvas.drawImage(img, 0, 0);
            canvas.parentElement.style.height = canvas.offsetHeight + "px";
            clearBtn.disabled = false;
            resolve(img);
          };
          img.src = imgURL;
          currentImageURL = imgURL;
        });
      }

      document.addEventListener("DOMContentLoaded", () => {
        for (const [id, model] of Object.entries(MODELS)) {
          const option = document.createElement("option");
          option.value = id;
          option.innerText = `${id} (${model.size})`;
          modelSelectEl.appendChild(option);
        }
      });
      async function getImageCaption(
        worker,
        weightsURL,
        tokenizerURL,
        configURL,
        modelID,
        imageURL,
        quantized,
        updateStatus = null
      ) {
        return new Promise((resolve, reject) => {
          worker.postMessage({
            weightsURL,
            tokenizerURL,
            configURL,
            modelID,
            imageURL,
            quantized,
          });
          function messageHandler(event) {
            if ("error" in event.data) {
              worker.removeEventListener("message", messageHandler);
              reject(new Error(event.data.error));
            }
            if (event.data.status === "complete") {
              worker.removeEventListener("message", messageHandler);
              resolve(event.data);
            }
            if (updateStatus) updateStatus(event.data);
          }
          worker.addEventListener("message", messageHandler);
        });
      }
      function updateStatus(data) {
        if (data.status === "status") {
          outputStatusEl.innerText = data.message;
        }
      }
      async function runInference(imageURL) {
        if (isCaptioning || !imageURL) {
          alert("Please select an image first");
          return;
        }

        outputStatusEl.hidden = false;
        outputCaptionEl.hidden = true;
        clearBtn.disabled = true;
        modelSelectEl.disabled = true;
        isCaptioning = true;
        const selectedModel = modelSelectEl.value;
        const model = MODELS[selectedModel];
        const weightsURL = `${model.base_url}${model.model}`;
        const tokenizerURL = `${model.base_url}${model.tokenizer}`;
        const configURL = `${model.base_url}${model.config}`;
        const quantized = model.quantized;
        try {
          const time = performance.now();
          const caption = await getImageCaption(
            blipWorker,
            weightsURL,
            tokenizerURL,
            configURL,
            selectedModel,
            imageURL,
            quantized,
            updateStatus
          );
          outputStatusEl.hidden = true;
          outputCaptionEl.hidden = false;
          const totalTime = ((performance.now() - time) / 1000).toFixed(2);
          outputCaptionEl.innerHTML = `${
            caption.output
          }<br/><span class="text-xs">Inference time: ${totalTime} s</span>`;
        } catch (err) {
          console.error(err);
          outputStatusEl.hidden = false;
          outputCaptionEl.hidden = true;
          outputStatusEl.innerText = err.message;
        }
        clearBtn.disabled = false;
        modelSelectEl.disabled = false;
        isCaptioning = false;
      }
    </script>
  </head>
  <body class="container max-w-4xl mx-auto p-4">
    <main class="grid grid-cols-1 gap-5 relative">
      <span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
      <div>
        <h1 class="text-5xl font-bold">Candle BLIP Image Captioning</h1>
        <h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
        <p class="max-w-lg">
          <a
            href="https://huggingface.co/Salesforce/blip-image-captioning-large"
            target="_blank"
            class="underline hover:text-blue-500 hover:no-underline"
            >BLIP Image Captioning
          </a>
          running in the browser using
          <a
            href="https://github.com/huggingface/candle/"
            target="_blank"
            class="underline hover:text-blue-500 hover:no-underline"
            >Candle</a
          >, a minimalist ML framework for Rust.
        </p>
        <p class="text-xs max-w-lg py-2">
          <b>Note:</b>
          Image captioning with the smallest model takes about 50 seconds;
          this will vary depending on your machine and the model size.
        </p>
      </div>

      <div>
        <label for="model" class="font-medium block">Model Options: </label>
        <select
          id="model"
          class="border-2 border-gray-500 rounded-md font-light interactive disabled:cursor-not-allowed w-full max-w-max"
        ></select>
      </div>
      <!-- drag and drop area -->
      <div class="grid gap-4 sm:grid-cols-2 py-4">
        <div class="relative max-w-lg">
          <div
            class="absolute w-full bottom-full flex justify-between items-center"
          >
            <div class="flex gap-2 w-full">
              <button
                id="clear-btn"
                disabled
                title="Clear Image"
                class="ml-auto text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center"
              >
                <svg
                  class=""
                  xmlns="http://www.w3.org/2000/svg"
                  viewBox="0 0 13 12"
                  height="1em"
                >
                  <path
                    d="M1.6.7 12 11.1M12 .7 1.6 11.1"
                    stroke="#2E3036"
                    stroke-width="2"
                  />
                </svg>
              </button>
            </div>
          </div>
          <div
            id="drop-area"
            class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative aspect-video w-full overflow-hidden"
          >
            <div
              class="flex flex-col items-center justify-center space-y-1 text-center"
            >
              <svg
                width="25"
                height="25"
                viewBox="0 0 25 25"
                fill="none"
                xmlns="http://www.w3.org/2000/svg"
              >
                <path
                  d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
                  fill="#000"
                />
              </svg>
              <div class="flex text-sm text-gray-600">
                <label
                  for="file-upload"
                  class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
                >
                  <span>Drag and drop your image here</span>
                  <span class="block text-xs">or</span>
                  <span class="block text-xs">Click to upload</span>
                </label>
              </div>
              <input
                id="file-upload"
                name="file-upload"
                type="file"
                class="sr-only"
              />
            </div>
            <canvas
              id="canvas"
              class="absolute pointer-events-none w-full"
            ></canvas>
          </div>
        </div>
        <div class="">
          <div
            class="h-full bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
          >
            <p
              id="output-caption"
              class="m-auto text-xl text-center p-2"
              hidden
            ></p>
            <span id="output-status" class="m-auto font-light">
              Please select an image
            </span>
          </div>
        </div>
      </div>

      <div>
        <div
          class="flex gap-3 items-center overflow-x-scroll"
          id="image-select"
        >
          <h3 class="font-medium">Examples:</h3>

          <img
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
            class="cursor-pointer w-24 h-24 object-cover"
          />
          <img
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
            class="cursor-pointer w-24 h-24 object-cover"
          />
          <img
            src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg"
            class="cursor-pointer w-24 h-24 object-cover"
          />
        </div>
      </div>
    </main>
  </body>
</html>
148
candle-wasm-examples/blip/src/bin/m.rs
Normal file
@ -0,0 +1,148 @@
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::blip;
use candle_transformers::models::quantized_blip;
use candle_wasm_example_blip::console_log;
use candle_wasm_example_blip::token_output_stream::TokenOutputStream;
use js_sys::Date;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;

enum SelectedModel {
    M(blip::BlipForConditionalGeneration),
    Q(quantized_blip::BlipForConditionalGeneration),
}

impl SelectedModel {
    fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor, JsError> {
        match self {
            Self::M(m) => m
                .text_decoder()
                .forward(xs, img_xs)
                .map_err(|e| JsError::new(&e.to_string())),
            Self::Q(m) => m
                .text_decoder()
                .forward(xs, img_xs)
                .map_err(|e| JsError::new(&e.to_string())),
        }
    }
    fn reset_kv_cache(&mut self) {
        match self {
            Self::M(m) => m.reset_kv_cache(),
            Self::Q(m) => m.reset_kv_cache(),
        }
    }
}
#[wasm_bindgen]
pub struct Model {
    model: SelectedModel,
    tokenizer: TokenOutputStream,
}
const SEP_TOKEN_ID: u32 = 102;

#[wasm_bindgen]
impl Model {
    #[wasm_bindgen(constructor)]
    pub fn load(
        weights: Vec<u8>,
        tokenizer: Vec<u8>,
        config: Vec<u8>,
        quantized: bool,
    ) -> Result<Model, JsError> {
        console_error_panic_hook::set_once();
        console_log!("loading model");
        let tokenizer =
            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
        let tokenizer = TokenOutputStream::new(tokenizer);

        let config: blip::Config = serde_json::from_slice(&config)?;
        let device = Device::Cpu;

        let start = Date::now();
        let model: SelectedModel = if quantized {
            let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
            let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
            SelectedModel::Q(model)
        } else {
            let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, &device)?;
            let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
            SelectedModel::M(model)
        };

        console_log!("model loaded in {:?}s", (Date::now() - start) / 1000.);
        Ok(Self { model, tokenizer })
    }
    #[wasm_bindgen]
    pub fn generate_caption_from_image(&mut self, image: Vec<u8>) -> Result<String, JsError> {
        self.model.reset_kv_cache();

        let device = Device::Cpu;
        console_log!("loading image as tensor");
        let start = Date::now();
        let image: Tensor = self.load_image(image)?.to_device(&device)?;
        console_log!("image loaded in {:?}s", (Date::now() - start) / 1000.);
        let start = Date::now();
        let image_embeds: Tensor = match &mut self.model {
            SelectedModel::M(m) => image.unsqueeze(0)?.apply(m.vision_model())?,
            SelectedModel::Q(m) => image.unsqueeze(0)?.apply(m.vision_model())?,
        };
        console_log!("image embedded in {:?}s", (Date::now() - start) / 1000.);
        let mut logits_processor = LogitsProcessor::new(299792458, None, None);
        let mut token_ids = vec![30522u32];
        let mut text: String = "".to_string();

        let start = Date::now();
        for index in 0..1000 {
            let context_size = if index > 0 { 1 } else { token_ids.len() };
            let start_pos = token_ids.len().saturating_sub(context_size);
            let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
            let logits = self.model.text_decoder_forward(&input_ids, &image_embeds)?;
            let logits = logits.squeeze(0)?;
            let logits = logits.get(logits.dim(0)? - 1)?;
            let token = logits_processor.sample(&logits)?;
            if token == SEP_TOKEN_ID {
                break;
            }
            token_ids.push(token);
            if let Some(t) = self.tokenizer.next_token(token)? {
                text.push_str(&t);
            }
        }
        if let Some(rest) = self
            .tokenizer
            .decode_rest()
            .map_err(|m| JsError::new(&m.to_string()))?
        {
            text.push_str(&rest);
        }
        console_log!("caption generated in {:?}s", (Date::now() - start) / 1000.);
        Ok(text)
    }
}

impl Model {
    fn load_image(&self, image: Vec<u8>) -> Result<Tensor, JsError> {
        let device = &Device::Cpu;
        let img = image::io::Reader::new(std::io::Cursor::new(image))
            .with_guessed_format()?
            .decode()
            .map_err(|e| JsError::new(&e.to_string()))?
            .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
        let img = img.to_rgb8();
        let data = img.into_raw();
        let data = Tensor::from_vec(data, (384, 384, 3), device)?.permute((2, 0, 1))?;
        let mean =
            Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], device)?.reshape((3, 1, 1))?;
        let std =
            Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], device)?.reshape((3, 1, 1))?;
        (data.to_dtype(candle::DType::F32)? / 255.)?
            .broadcast_sub(&mean)?
            .broadcast_div(&std)
            .map_err(|e| JsError::new(&e.to_string()))
    }
}

fn main() {
    console_error_panic_hook::set_once();
}
17
candle-wasm-examples/blip/src/lib.rs
Normal file
@ -0,0 +1,17 @@
use wasm_bindgen::prelude::*;
pub mod token_output_stream;

#[wasm_bindgen]
extern "C" {
    // Use `js_namespace` here to bind `console.log(..)` instead of just
    // `log(..)`
    #[wasm_bindgen(js_namespace = console)]
    pub fn log(s: &str);
}

#[macro_export]
macro_rules! console_log {
    // Note that this is using the `log` function imported above during
    // `bare_bones`
    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}
86
candle-wasm-examples/blip/src/token_output_stream.rs
Normal file
@ -0,0 +1,86 @@
use candle::Result;

/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
pub struct TokenOutputStream {
    tokenizer: tokenizers::Tokenizer,
    tokens: Vec<u32>,
    prev_index: usize,
    current_index: usize,
}

impl TokenOutputStream {
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
        }
    }

    pub fn into_inner(self) -> tokenizers::Tokenizer {
        self.tokenizer
    }

    fn decode(&self, tokens: &[u32]) -> Result<String> {
        match self.tokenizer.decode(tokens, true) {
            Ok(str) => Ok(str),
            Err(err) => candle::bail!("cannot decode: {err}"),
        }
    }

    // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    pub fn decode_rest(&self) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() {
            let text = text.split_at(prev_text.len());
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }

    pub fn decode_all(&self) -> Result<String> {
        self.decode(&self.tokens)
    }

    pub fn get_token(&self, token_s: &str) -> Option<u32> {
        self.tokenizer.get_vocab(true).get(token_s).copied()
    }

    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
        &self.tokenizer
    }

    pub fn clear(&mut self) {
        self.tokens.clear();
        self.prev_index = 0;
        self.current_index = 0;
    }
}
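The doc comment above describes the streaming contract; a short sketch of the intended decode loop (`sample_next_token` is a stand-in for whatever produces the next token id, e.g. a `LogitsProcessor`):

```rust
use candle_wasm_example_blip::token_output_stream::TokenOutputStream;

fn stream_decode(
    tokenizer: tokenizers::Tokenizer,
    mut sample_next_token: impl FnMut() -> Option<u32>,
) -> candle::Result<String> {
    let mut stream = TokenOutputStream::new(tokenizer);
    let mut text = String::new();
    // Push tokens one at a time; a chunk is emitted as soon as it decodes unambiguously.
    while let Some(token) = sample_next_token() {
        if let Some(chunk) = stream.next_token(token)? {
            text.push_str(&chunk);
        }
    }
    // Flush whatever is still buffered once generation stops.
    if let Some(rest) = stream.decode_rest()? {
        text.push_str(&rest);
    }
    Ok(text)
}
```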
@ -1,5 +1,7 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use candle_nn::{
    embedding, linear_no_bias as linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder,
};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

@ -57,20 +59,6 @@ impl Cache {
    }
}

fn silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    Ok(Linear::new(weight, None))
}

fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
    let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
    Ok(Embedding::new(embeddings, cfg.dim))
}

struct CausalSelfAttention {
    q_proj: Linear,
    k_proj: Linear,
@ -198,7 +186,7 @@ impl Mlp {
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&x)
    }

@ -283,7 +271,7 @@ impl Llama {
    }

    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
        let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
        let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.n_layers)
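The inlined `silu` that this diff removes and the `candle_nn::ops::silu` that replaces it compute the same function, silu(x) = x * sigmoid(x) = x / (1 + exp(-x)); a quick sketch checking that the two forms agree:

```rust
use candle::{Device, Result, Tensor};

fn check_silu() -> Result<()> {
    let xs = Tensor::new(&[-2.0f32, -0.5, 0.0, 1.5], &Device::Cpu)?;
    // The formulation previously inlined in llama2-c: x / (1 + e^(-x)).
    let manual = (&xs / (xs.neg()?.exp()? + 1.0)?)?;
    let builtin = candle_nn::ops::silu(&xs)?;
    let diff = (manual - builtin)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert!(diff < 1e-6);
    Ok(())
}
```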