Module Docs (#2620)

* update bert docs

* update based

* update bigcode

* add pixtral

* add flux as well
This commit is contained in:
zachcp
2024-11-16 03:09:17 -05:00
committed by GitHub
parent 00d8a0c178
commit a3f200e369
5 changed files with 126 additions and 10 deletions

View File

@ -1,9 +1,9 @@
//! Based from the Stanford Hazy Research group. //! Based from the Stanford Hazy Research group.
//! //!
//! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024 //! See "Simple linear attention language models balance the recall-throughput tradeoff", Arora et al. 2024
//! - [Arxiv](https://arxiv.org/abs/2402.18668) //! - Simple linear attention language models balance the recall-throughput tradeoff. [Arxiv](https://arxiv.org/abs/2402.18668)
//! - [Github](https://github.com/HazyResearch/based) //! - [Github Rep](https://github.com/HazyResearch/based)
//! //! - [Blogpost](https://hazyresearch.stanford.edu/blog/2024-03-03-based)
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{ use candle_nn::{

View File

@ -1,8 +1,61 @@
//! BERT (Bidirectional Encoder Representations from Transformers) //! BERT (Bidirectional Encoder Representations from Transformers)
//! //!
//! See "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding", Devlin et al. 2018 //! Bert is a general large language model that can be used for various language tasks:
//! - [Arxiv](https://arxiv.org/abs/1810.04805) //! - Compute sentence embeddings for a prompt.
//! - [Github](https://github.com/google-research/bert) //! - Compute similarities between a set of sentences.
//! - [Arxiv](https://arxiv.org/abs/1810.04805) "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
//! - Upstream [Github repo](https://github.com/google-research/bert).
//! - See bert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code
//!
//! ```no_run
//! // for sentence embeddings
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let model = todo!();
//! # let prompt = "Here is a test sentence";
//! let embeddings = model.forward(prompt)?;
//! // Returns tensor of shape [1, 7, 384]
//! println!("{embeddings}");
//! # Ok(())
//! # }
//!
//! // Different models can be loaded using the model ID
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let vb = todo!();
//! # let config = todo!();
//! let model = BertModel::load(vb, &config )?;
//! # Ok(())
//! # }
//!
//! // Gelu approximation
//! // You can get a speedup by configuring the model
//! // to use an approximation of the gelu activation:
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let mut config = todo!();
//! config.hidden_act = HiddenAct::GeluApproximate;
//! # Ok(())
//! # }
//!
//! // Similarities
//! // Bert can compute sentence embeddings which can then be used to calculate
//! // semantic similarities between sentences through cosine similarity scoring.
//! // The sentence embeddings are computed using average pooling across all tokens.
//! # use candle_core::Tensor;
//! # use candle_nn::{VarBuilder, Module};
//! # fn main() -> candle_core::Result<()> {
//! # let model = todo!();
//! let sentence1 = "The new movie is awesome";
//! let sentence2 = "The new movie is so great";
//! let emb1 = model.forward(sentence1)?;
//! let emb2 = model.forward(sentence2)?;
//! # Ok(())
//! # }
//! ```
//! //!
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear}; use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor}; use candle::{DType, Device, Result, Tensor};

View File

@ -1,9 +1,25 @@
//! BigCode implementation in Rust based on the GPT-BigCode model. //! BigCode implementation in Rust based on the GPT-BigCode model.
//! //!
//! See "StarCoder: A State-of-the-Art LLM for Code", Mukherjee et al. 2023 //! [StarCoder/BigCode](https://huggingface.co/bigcode/starcoderbase-1b) is a LLM
//! model specialized to code generation. The initial model was trained on 80
//! programming languages. See "StarCoder: A State-of-the-Art LLM for Code", Mukherjee et al. 2023
//! - [Arxiv](https://arxiv.org/abs/2305.06161) //! - [Arxiv](https://arxiv.org/abs/2305.06161)
//! - [Github](https://github.com/bigcode-project/starcoder) //! - [Github](https://github.com/bigcode-project/starcoder)
//! //!
//! ## Running some example
//!
//! ```bash
//! cargo run --example bigcode --release -- --prompt "fn fact(n: u64) -> u64"
//!
//! > fn fact(n: u64) -> u64 {
//! > if n == 0 {
//! > 1
//! > } else {
//! > n * fact(n - 1)
//! > }
//! > }
//! ```
//!
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};

View File

@ -1,10 +1,26 @@
//! Flux Model //! Flux Model
//! //!
//! Flux is a series of text-to-image generation models based on diffusion transformers. //! Flux is a 12B rectified flow transformer capable of generating images from text descriptions.
//! //!
//! - [GH Link](https://github.com/black-forest-labs/flux) //! - [Hugging Face Model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py) //! - [GitHub Repository](https://github.com/black-forest-labs/flux)
//! - [Blog Post](https://blackforestlabs.ai/announcing-black-forest-labs/)
//! //!
//! # Usage
//!
//! ```bash
//! cargo run --features cuda \
//! --example flux -r -- \
//! --height 1024 --width 1024 \
//! --prompt "a rusty robot walking on a beach holding a small torch, \
//! the robot has the word \"rust\" written on it, high quality, 4k"
//! ```
//!
//! <div align=center>
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/flux/assets/flux-robot.jpg" alt="" width=320>
//! </div>
//!
use candle::{Result, Tensor}; use candle::{Result, Tensor};
pub trait WithForward { pub trait WithForward {

View File

@ -4,7 +4,38 @@
//! using images paired with text descriptions. //! using images paired with text descriptions.
//! //!
//! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/pixtral) //! - Transformers Python [reference implementation](https://github.com/huggingface/transformers/tree/main/src/transformers/models/pixtral)
//! - [Blog Post](https://mistral.ai/news/pixtral-12b/) -
//! - [HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) -
//! - [HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b).
//! //!
//! # Example
//!
//! <div align=center>
//! <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/flux/assets/flux-robot.jpg" alt="" width=320>
//! </div>
//!
//! ```bash
//! cargo run --profile=release-with-debug \
//! --features cuda \
//! --example pixtral -- \
//! --image candle-examples/examples/flux/assets/flux-robot.jpg
//! ```
//!
//! ```txt
//! Describe the image.
//!
//! The image depicts a charming, rustic robot standing on a sandy beach at sunset.
//! The robot has a vintage, steampunk aesthetic with visible gears and mechanical
//! parts. It is holding a small lantern in one hand, which emits a warm glow, and
//! its other arm is extended forward as if reaching out or guiding the way. The
//! robot's body is adorned with the word "RUST" in bright orange letters, adding to
//! its rustic theme.
//!
//! The background features a dramatic sky filled with clouds, illuminated by the
//! setting sun, casting a golden hue over the scene. Gentle waves lap against the
//! shore, creating a serene and picturesque atmosphere. The overall mood of the
//! image is whimsical and nostalgic, evoking a sense of adventure and tranquility.
//! ```
pub mod llava; pub mod llava;
pub mod vision_model; pub mod vision_model;