mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
More Model Module Docs (#2623)
* dinov2
* add another example
* add dinov2reg4
* eva2
* efficientvit
* moondream
* update t5
* update t5
* rwkv
* stable diffusion docs
* add wasm link
* add segment_anything
* adjust for clippy
* ignore bertdoc
* dinov2 ignore
* update block to be text
* remove the rust blocks for the moment
* bump python to 3.11
* add a setup-python step
* add py311 to test as well
.github/workflows/rust-ci.yml (vendored): 6 changes
@@ -16,6 +16,9 @@ jobs:
        rust: [stable]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - uses: actions-rs/toolchain@v1
        with:
          profile: minimal
@@ -35,6 +38,9 @@ jobs:
        rust: [stable]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - uses: actions-rs/toolchain@v1
        with:
          profile: minimal
@@ -7,56 +7,6 @@
//! - Upstream [Github repo](https://github.com/google-research/bert).
//! - See bert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code
//!
//! ```no_run
//! // for sentence embeddings
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let model = todo!();
//! # let prompt = "Here is a test sentence";
//! let embeddings = model.forward(prompt)?;
//! // Returns tensor of shape [1, 7, 384]
//! println!("{embeddings}");
//! # Ok(())
//! # }
//!
//! // Different models can be loaded using the model ID
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let vb = todo!();
//! # let config = todo!();
//! let model = BertModel::load(vb, &config)?;
//! # Ok(())
//! # }
//!
//! // Gelu approximation
//! // You can get a speedup by configuring the model
//! // to use an approximation of the gelu activation:
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let mut config = todo!();
//! config.hidden_act = HiddenAct::GeluApproximate;
//! # Ok(())
//! # }
//!
//! // Similarities
//! // Bert can compute sentence embeddings which can then be used to calculate
//! // semantic similarities between sentences through cosine similarity scoring.
//! // The sentence embeddings are computed using average pooling across all tokens.
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let model = todo!();
//! let sentence1 = "The new movie is awesome";
//! let sentence2 = "The new movie is so great";
//! let emb1 = model.forward(sentence1)?;
//! let emb2 = model.forward(sentence2)?;
//! # Ok(())
//! # }
//! ```
//!
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
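For readers who want a buildable starting point outside the doc comments, here is a minimal sketch of what the removed bert example spells out: loading a checkpoint through `VarBuilder` and `BertModel::load`. The file paths are placeholders, the `anyhow` and `serde_json` crates are assumed (as in the candle examples), and the three-argument `forward` signature (ids, type ids, optional attention mask) should be checked against the candle-transformers version in use.

```rust
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;

    // Placeholder paths: the bert example resolves these through the hf-hub API.
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let model = BertModel::load(vb, &config)?;

    // Dummy token ids; real ones come from the matching WordPiece tokenizer.
    let token_ids = Tensor::new(&[[101u32, 7592, 2088, 102]], &device)?;
    let token_type_ids = token_ids.zeros_like()?;
    // Assumes the three-argument forward (ids, type ids, optional attention mask);
    // older candle-transformers revisions took only the first two.
    let embeddings = model.forward(&token_ids, &token_type_ids, None)?;
    println!("{embeddings}");
    Ok(())
}
```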
@@ -1,8 +1,42 @@
//! Implementation of the DINOv2 models from Meta Research.
//!
//! See:
//! - DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2)
//! This module implements the DINOv2 vision transformer model from Meta AI Research.
//! DINOv2 is a self-supervised learning model that can learn visual features
//! without using any labeled data. See: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2)
//!
//! ## Running an example with color map and CUDA
//!
//! ```bash
//! cargo run \
//!   --features cuda,depth_anything_v2 \
//!   --package candle-examples \
//!   --example depth_anything_v2 \
//!   -- --color-map \
//!   --image candle-examples/examples/yolo-v8/assets/bike.jpg
//! ```
//!
//! ## Running as an ImageNet classifier
//!
//! The model returns the probability for the image to belong to each of the 1000 ImageNet categories.
//!
//! <div align=center>
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width=640>
//! </div>
//!
//! ```bash
//! cargo run \
//!   --example dinov2 \
//!   --release \
//!   -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
//!
//! > mountain bike, all-terrain bike, off-roader: 43.67%
//! > bicycle-built-for-two, tandem bicycle, tandem: 33.20%
//! > crash helmet : 13.23%
//! > unicycle, monocycle : 2.44%
//! > maillot : 2.42%
//! ```
//!

use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
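As a companion to the classifier example above, a rough sketch of how this module is driven from Rust: build a `VarBuilder` over the DINOv2 weights, construct the small ViT variant, and run a forward pass. The `vit_small` constructor and the 518x518 input resolution follow the candle dinov2 example; the weight path and the assumption that the forward pass yields ImageNet logits are placeholders to verify against that example.

```rust
use candle_core::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2;

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Placeholder weight file; the example fetches the ViT-S/14 safetensors from the hub.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["dinov2_vits14.safetensors"], DType::F32, &device)?
    };
    let model = dinov2::vit_small(vb)?;

    // Stand-in for a real image: the example feeds a 518x518, ImageNet-normalized picture.
    let image = Tensor::randn(0f32, 1f32, (1, 3, 518, 518), &device)?;
    let logits = model.forward(&image)?;
    // Softmax over the (assumed) 1000 ImageNet classes gives the percentages printed above.
    let probs = candle_nn::ops::softmax(&logits, candle_core::D::Minus1)?;
    println!("{probs}");
    Ok(())
}
```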
@@ -1,9 +1,34 @@
//! Implementation of the DINOv2 revision (4 regularization tokens)
//!
//! See:
//! - DINOv2: ["DINOv2: Learning Robust Visual Features without Supervision"](https://github.com/facebookresearch/dinov2)
//! The DINOv2-reg4 model is a variant of DINOv2 that adds 4 regularization tokens to the
//! original architecture. This implementation is specifically trained for plant species
//! classification on the PlantCLEF2024 dataset with 7,806 classes.
//!
//! This code implements the regularization tokens version with 4 regularization tokens.
//! - [Paper](https://arxiv.org/abs/2309.16588): Vision Transformers Need Registers
//! - [GH Repo](https://github.com/facebookresearch/dinov2)
//!
//! # Example
//!
//! ```bash
//! # Download the class names and a plant picture to identify;
//! # see candle/examples/dinov2reg4 for the full code.
//!
//! # Perform inference
//! cargo run \
//!   --example dinov2reg4 \
//!   --release -- \
//!   --image <orchid-file>
//!
//! > Orchis simia Lam. : 45.55%
//! > Orchis × bergonii Nanteuil: 9.80%
//! > Orchis italica Poir. : 9.66%
//! > Orchis × angusticruris Franch.: 2.76%
//! > Orchis × bivonae Tod. : 2.54%
//! ```
//!
//! <div align=center>
//! <img src="https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c" alt="" width=320>
//! </div>
//!
use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
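The percentages in the listing above are simply a softmax over the classifier logits followed by a top-k readout. A small, self-contained sketch of that post-processing step; the logits tensor and the label list are stand-ins for the real model output and the PlantCLEF2024 label file.

```rust
use candle_core::{Device, Result, Tensor, D};

/// Turn raw classifier logits into the top-k (label, probability) pairs the example prints.
fn top_k(logits: &Tensor, labels: &[&str], k: usize) -> Result<Vec<(String, f32)>> {
    let probs = candle_nn::ops::softmax(logits, D::Minus1)?
        .flatten_all()?
        .to_vec1::<f32>()?;
    let mut indexed: Vec<(usize, f32)> = probs.into_iter().enumerate().collect();
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    Ok(indexed
        .into_iter()
        .take(k)
        .map(|(i, p)| (labels.get(i).unwrap_or(&"?").to_string(), p))
        .collect())
}

fn main() -> Result<()> {
    // Stand-in logits for a 4-class toy problem; the real model emits 7,806 of them.
    let logits = Tensor::new(&[[2.0f32, 0.5, 0.1, -1.0]], &Device::Cpu)?;
    let labels = ["Orchis simia Lam.", "Orchis italica Poir.", "other", "background"];
    for (name, p) in top_k(&logits, &labels, 3)? {
        println!("{name:30}: {:.2}%", p * 100.0);
    }
    Ok(())
}
```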
@@ -1,9 +1,40 @@
//! EfficientViT (MSRA) inference implementation based on timm.
//!
//! See ["EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"](https://arxiv.org/abs/2305.07027)
//! This module provides an implementation of the EfficientViT model from Microsoft Research Asia
//! for efficient image classification. The model uses cascaded group attention modules
//! to achieve strong performance while maintaining low memory usage.
//!
//! The model was originally described in the paper:
//! ["EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"](https://arxiv.org/abs/2305.07027)
//!
//! This implementation is based on the reference implementation from
//! [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py).
//!
//! # Example Usage
//!
//! This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
//! The classification head has been trained on the ImageNet dataset; the example prints the probabilities for the top-5 classes.
//!
//! ```bash
//! cargo run \
//!   --example efficientvit \
//!   --release -- \
//!   --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1
//!
//! > loaded image Tensor[dims 3, 224, 224; f32]
//! > model built
//! > mountain bike, all-terrain bike, off-roader: 69.80%
//! > unicycle, monocycle : 13.03%
//! > bicycle-built-for-two, tandem bicycle, tandem: 9.28%
//! > crash helmet : 2.25%
//! > alp : 0.46%
//! ```
//!
//! <div align=center>
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width=640>
//! </div>
//!
//! Based on implementation from [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py)

use candle::{Result, Tensor, D};
use candle_nn::{
    batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, ops::softmax, Conv2dConfig, Func,
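All of the image-classifier examples touched by this commit (efficientvit, dinov2, eva2) start from the same preprocessing step: load the picture, resize it to the model's input resolution, and normalize with the ImageNet mean and standard deviation. A hedged sketch of that step using the `image` and `anyhow` crates; the candle-examples crate ships equivalent helpers.

```rust
use candle_core::{DType, Device, Tensor};

/// Load an image, resize it to 224x224 and normalize it with the ImageNet
/// mean/std, returning a (3, 224, 224) f32 tensor for the classifier.
fn load_image224(path: &str, device: &Device) -> anyhow::Result<Tensor> {
    let img = image::open(path)?
        .resize_exact(224, 224, image::imageops::FilterType::Triangle)
        .to_rgb8();
    let data = img.into_raw(); // HWC bytes
    let pixels = (Tensor::from_vec(data, (224, 224, 3), device)?
        .permute((2, 0, 1))? // to CHW
        .to_dtype(DType::F32)?
        / 255.0)?;
    let mean = Tensor::new(&[0.485f32, 0.456, 0.406], device)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.229f32, 0.224, 0.225], device)?.reshape((3, 1, 1))?;
    Ok(pixels.broadcast_sub(&mean)?.broadcast_div(&std)?)
}

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    let image = load_image224("candle-examples/examples/yolo-v8/assets/bike.jpg", &device)?;
    println!("loaded image {image:?}"); // e.g. Tensor[dims 3, 224, 224; f32]
    Ok(())
}
```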
@@ -1,9 +1,31 @@
//! EVA-2 inference implementation.
//!
//! See ["EVA-02: A Visual Representation for Neon Genesis"](https://arxiv.org/abs/2303.11331)
//! EVA-02 is a computer vision model that can be used as an ImageNet classifier.
//! The model returns the probability for an image to belong to each of the 1000
//! ImageNet categories.
//!
//! - [Paper](https://arxiv.org/abs/2303.11331). EVA-02: A Visual Representation for Neon Genesis
//! - [Code](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/eva2.py)
//!
//! # Example
//!
//! ```bash
//! cargo run \
//!   --example eva2 \
//!   --release -- \
//!   --image candle-examples/examples/yolo-v8/assets/bike.jpg
//!
//! > mountain bike, all-terrain bike, off-roader: 37.09%
//! > maillot : 8.30%
//! > alp : 2.13%
//! > bicycle-built-for-two, tandem bicycle, tandem: 0.84%
//! > crash helmet : 0.73%
//! ```
//!
//! <div align=center>
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width=640>
//! </div>
//!
//! Based on implementation from [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/eva2.py)

use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
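Like the other classifier examples, the eva2 example fetches its weights from the Hugging Face Hub rather than from a local path. A short sketch of that download step with the `hf-hub` crate; the repository and file names here are illustrative placeholders, not necessarily the ones the example uses.

```rust
use candle_core::{DType, Device};
use candle_nn::VarBuilder;

fn main() -> anyhow::Result<()> {
    // Fetch a safetensors file from the Hub (cached locally after the first run).
    // Repo and file names are placeholders; check the eva2 example for the real ones.
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("timm/eva02_base_patch14_448.mim_in22k_ft_in22k_in1k".to_string());
    let weights = repo.get("model.safetensors")?;

    let device = Device::Cpu;
    let _vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)? };
    // ...build the model from `_vb` as in the example...
    Ok(())
}
```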
@@ -1,13 +1,39 @@
//! MoonDream Model vision-to-text
//!
//!
//! Moondream is a computer-vision model that can answer real-world questions about images.
//! It's lightweight with only 1.6B parameters, enabling it to run on mobile phones and edge devices.
//! [MoonDream Original Implementation](https://github.com/vikhyat/moondream)
//!
//! The model consists of:
//! - Vision encoder using a ViT-style architecture
//! - Text decoder based on Microsoft's Phi model
//! - Vision projection module to align vision and text embeddings
//!
//! References:
//! - [MoonDream Original Implementation](https://github.com/vikhyat/moondream)
//!
//! # Examples
//!
//! <img src="https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg" width="200">
//!
//! ```bash
//! # download an example image
//! wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg
//!
//! # Now you can run Moondream from the `candle-examples` crate:
//! cargo run --example moondream \
//!   --release -- \
//!   --prompt "What is the girl eating?" \
//!   --image "./demo-1.jpg"
//!
//! > avx: false, neon: true, simd128: false, f16c: false
//! > temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
//! > retrieved the files in 3.395583ms
//! > Running on CPU, to run on GPU(metal), build this example with `--features metal`
//! > loaded the model in 5.485493792s
//! > loaded and encoded the image Tensor[dims 3, 378, 378; f32] in 4.801396417s
//! > starting the inference loop
//! > The girl is eating a hamburger.<
//! > 9 tokens generated (0.68 token/s)
//! ```

use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
@@ -1,7 +1,9 @@
//! RWKV v5 model implementation.
//!
//! RWKV is an RNN with transformer-level performance that can be implemented
//! as either a transformer or RNN.
//! The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
//! with performance on par with transformer architectures. Several variants are
//! available; candle implements the v5 and v6 versions and can be used with
//! Eagle 7B ([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
//!
//! Key characteristics:
//! - Time-mix attention mechanism
@@ -14,6 +16,20 @@
//! - [RWKV Language Model](https://github.com/BlinkDL/RWKV-LM)
//! - [RWKV v5 Release](https://github.com/BlinkDL/ChatRWKV/tree/main)
//!
//! # Example
//!
//! ```bash
//! cargo run --example rwkv --release -- \
//!   --prompt "The smallest prime is "
//!
//! > avx: true, neon: false, simd128: false, f16c: true
//! > temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
//! > The smallest prime is ϕ(2) = 2.
//! > The smallest composite is ϕ(3) = 3.
//! > The smallest perfect number is ϕ(5) = 5.
//! > The smallest perfect square is ϕ(4) = 4.
//! > The smallest perfect cube is ϕ(6) = 6.
//! ```

use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor};
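The `temp`/`repeat-penalty`/`repeat-last-n` values in the log above are the sampling knobs of the example's generation loop. A hedged sketch of that part of the loop using `LogitsProcessor` and `apply_repeat_penalty` from candle-transformers; the logits tensor and the token history are stand-ins for what the RWKV forward pass and the tokenizer would produce.

```rust
use candle_core::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Stand-in logits over a tiny vocabulary; the model would produce these per step.
    let logits = Tensor::new(&[0.1f32, 2.5, -0.3, 0.7], &device)?;
    let generated_tokens: Vec<u32> = vec![1, 3, 3];

    // Temperature 0 means greedy (argmax) sampling, matching `temp: 0.00` in the log.
    let mut processor = LogitsProcessor::new(299792458, Some(0.0), None);

    // Penalize tokens that already occurred in the last `repeat_last_n` positions.
    let repeat_penalty = 1.1f32;
    let repeat_last_n = 64usize;
    let start = generated_tokens.len().saturating_sub(repeat_last_n);
    let logits = candle_transformers::utils::apply_repeat_penalty(
        &logits,
        repeat_penalty,
        &generated_tokens[start..],
    )?;

    let next_token = processor.sample(&logits)?;
    println!("next token id: {next_token}");
    Ok(())
}
```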
@@ -1,7 +1,9 @@
//! RWKV v6 model implementation.
//!
//! RWKV is an RNN with transformer-like performance.
//! Version 6 introduces refinements to the architecture.
//! The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
//! with performance on par with transformer architectures. Several variants are
//! available; candle implements the v5 and v6 versions and can be used with
//! Eagle 7B ([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
//!
//! Key characteristics:
//! - Linear attention mechanism
@@ -10,9 +12,20 @@
//! - Feed forward gating
//! - State recycling for efficient inference
//!
//! References:
//! - [RWKV Model](https://github.com/BlinkDL/RWKV-LM)
//!
//! # Example
//!
//! ```bash
//! cargo run --example rwkv --release -- \
//!   --prompt "The smallest prime is "
//!
//! > avx: true, neon: false, simd128: false, f16c: true
//! > temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
//! > The smallest prime is ϕ(2) = 2.
//! > The smallest composite is ϕ(3) = 3.
//! > The smallest perfect number is ϕ(5) = 5.
//! > The smallest perfect square is ϕ(4) = 4.
//! > The smallest perfect cube is ϕ(6) = 6.
//! ```

use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{IndexOp, Result, Tensor};
@@ -1,10 +1,33 @@
//! Segment Anything Model (SAM)
//!
//! SAM is an architecture for image segmentation, capable of segmenting any object
//! in an image based on prompts like points or boxes.
//! This model provides a robust and fast image segmentation pipeline that can be tweaked via
//! some prompting (requesting some points to be in the target mask, requesting some
//! points to be part of the background so _not_ in the target mask, specifying some
//! bounding box).
//!
//! - [GH Link](https://github.com/facebookresearch/segment-anything)
//! - [Paper](https://arxiv.org/abs/2304.02643)
//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/candle-segment-anything-wasm)
//! - 💻 [GH Link](https://github.com/facebookresearch/segment-anything)
//! - 📝 [Paper](https://arxiv.org/abs/2304.02643)
//! - 💡 The default backbone can be replaced by the smaller and faster TinyViT model based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
//!
//! ## Example
//!
//! ```bash
//! cargo run --example segment-anything --release -- \
//!   --image candle-examples/examples/yolo-v8/assets/bike.jpg \
//!   --use-tiny --point 0.6,0.6 --point 0.6,0.55
//! ```
//!
//! <div align=center style="display: flex; justify-content: center; gap: 10px;">
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/yolo-v8/assets/bike.jpg" alt="" width="30%">
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/single_pt_prompt.jpg" alt="" width="30%">
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/two_pt_prompt.jpg" alt="" width="30%">
//! </div>
//!
//! > Original; Prompt with `--point 0.6,0.55`; Prompt with `--point 0.6,0.6 --point 0.6,0.55`
//!
pub use crate::models::with_tracing::Linear;
use candle::{Result, Tensor};
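For orientation, a rough sketch of how the example drives the model once the weights are in a `VarBuilder`: build the TinyViT-backed variant and call `forward` with the image tensor and the point prompts. The `new_tiny` constructor, the `(x, y, is_positive)` point encoding, the u8 image input, and the returned `(mask, iou_predictions)` pair follow the segment-anything example at the time of this commit; treat them as assumptions to verify, not a stable API.

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::segment_anything::sam;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Placeholder weight file; the example fetches the MobileSAM/TinyViT safetensors from the hub.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["mobile_sam-tiny-vitt.safetensors"], DType::F32, &device)?
    };
    let sam = sam::Sam::new_tiny(vb)?;

    // Stand-in image, resized so the longest side is 1024 as in the example.
    let image = Tensor::zeros((3, 1024, 1024), DType::U8, &device)?;
    // Point prompts in relative coordinates; `true` marks a foreground point.
    let points = vec![(0.6, 0.6, true), (0.6, 0.55, true)];
    let (mask, iou_predictions) = sam.forward(&image, &points, false)?;
    println!("mask: {mask:?}, iou: {iou_predictions:?}");
    Ok(())
}
```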
@@ -5,7 +5,37 @@
//!
//! - [Original Repository](https://github.com/CompVis/stable-diffusion)
//! - [Hugging Face](https://huggingface.co/runwayml/stable-diffusion-v1-5)
//! - The default scheduler for the v1.5, v2.1 and XL 1.0 versions is the Denoising Diffusion Implicit Model scheduler (DDIM). The original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim). The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
//!
//! # Example
//!
//! <div align=center>
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" alt="rusty robot holding a candle" width=320>
//! </div>
//!
//! _"A rusty robot holding a fire torch in its hand."_ Generated by Stable Diffusion XL using Rust and [candle](https://github.com/huggingface/candle).
//!
//! ```bash
//! # example running with cuda
//! # see candle-examples/examples/stable-diffusion for all options
//! cargo run --example stable-diffusion --release --features=cuda,cudnn \
//!   -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
//!
//! # with sd-turbo
//! cargo run --example stable-diffusion --release --features=cuda,cudnn \
//!   -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" \
//!   --sd-version turbo
//!
//! # with flash attention.
//! # feature flag: `--features flash-attn`
//! # cli flag: `--use-flash-attn`.
//! # flash-attention-v2 is only compatible with Ampere, Ada, \
//! # or Hopper GPUs (e.g., A100/H100, RTX 3090/4090).
//! cargo run --example stable-diffusion --release --features=cuda,cudnn \
//!   -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" \
//!   --use-flash-attn
//! ```

pub mod attention;
pub mod clip;
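A hedged sketch of the scheduler choice described above: the `StableDiffusionConfig` constructors pick the defaults for each version (DDIM for v1.5/v2.1/XL, Euler Ancestral for Turbo). The constructor arguments (sliced-attention size, height, width) and the `build_scheduler` call are assumed to match the candle-examples stable-diffusion code; verify against the version in use.

```rust
use candle_transformers::models::stable_diffusion::StableDiffusionConfig;

fn main() -> candle_core::Result<()> {
    // 512x512 v1.5 pipeline; `None` keeps the default sliced-attention setting.
    let sd_config = StableDiffusionConfig::v1_5(None, Some(512), Some(512));

    // The default scheduler for this version is DDIM; XL Turbo would get Euler Ancestral.
    let n_steps = 30;
    let _scheduler = sd_config.build_scheduler(n_steps)?;
    println!("built a {n_steps}-step scheduler for v1.5");
    Ok(())
}
```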
@@ -14,6 +14,49 @@
//! - [T5 Paper](https://arxiv.org/abs/1910.10683)
//! - [HuggingFace T5](https://huggingface.co/docs/transformers/model_doc/t5)
//! - [GH Model](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py)
//!
//! # Encoder-decoder example
//!
//! ```bash
//! cargo run --example t5 --release -- \
//!   --model-id "t5-small" \
//!   --prompt "translate to German: A beautiful candle." \
//!   --decode
//! > ...
//! > Eine schöne Kerze.
//! > 9 tokens generated (2.42 token/s)
//! ```
//!
//! Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
//!
//! # Translation with MADLAD
//!
//! [MADLAD-400](https://arxiv.org/abs/2309.04662) is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
//!
//! ```bash
//! cargo run --example t5 --release -- \
//!   --model-id "jbochi/madlad400-3b-mt" \
//!   --prompt "<2de> How are you, my friend?" \
//!   --decode --temperature 0
//! ...
//! Wie geht es dir, mein Freund?
//! ```
//!
//! ## Sentence embedding example
//!
//! ```bash
//! cargo run --example t5 --release -- \
//!   --model-id "t5-small" --prompt "A beautiful candle."
//! ...
//! [[[ 0.0515, -0.0541, -0.0761, ..., -0.0392, 0.1511, -0.0265],
//!  [-0.0974, 0.0998, -0.1659, ..., -0.2450, 0.1738, -0.0164],
//!  [ 0.0624, -0.1024, 0.0430, ..., -0.1388, 0.0564, -0.2962],
//!  [-0.0389, -0.1173, 0.0026, ..., 0.1064, -0.1065, 0.0990],
//!  [ 0.1300, 0.0027, -0.0326, ..., 0.0026, -0.0317, 0.0851]]]
//! Tensor[[1, 5, 512], f32]
//! Took 303.766583ms
//! ```

use crate::models::with_tracing::Embedding;
use candle::{DType, Device, Module, Result, Tensor, D};
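The sentence-embedding output shown above is the raw encoder hidden states. A hedged sketch of that path: tokenize the prompt, load `T5EncoderModel`, and run the encoder. The config and weight paths are placeholders, and the `tokenizers` and `anyhow` crates are assumed, as in the candle t5 example.

```rust
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::t5::{Config, T5EncoderModel};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;

    // Placeholder paths; the example resolves them through the hf-hub API for "t5-small".
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let mut model = T5EncoderModel::load(vb, &config)?;

    // Tokenize the prompt with the matching tokenizer.json.
    let tokenizer = tokenizers::Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;
    let tokens = tokenizer
        .encode("A beautiful candle.", true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();
    let input_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;

    // Encoder hidden states, e.g. Tensor[[1, 5, 512], f32] for t5-small.
    let embeddings = model.forward(&input_ids)?;
    println!("{embeddings}");
    Ok(())
}
```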