mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Use the hub files for the marian example. (#1220)
* Use the hub files for the marian example. * Use the secondary decoder. * Add a readme. * More readme.
This commit is contained in:
@ -103,6 +103,8 @@ We also provide some command line based examples using state of the art models
|
|||||||
evaluation, segmentation).
|
evaluation, segmentation).
|
||||||
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
|
||||||
generate captions for an image.
|
generate captions for an image.
|
||||||
|
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
|
||||||
|
model that generates translated text from the input text.
|
||||||
|
|
||||||
Run them using commands like:
|
Run them using commands like:
|
||||||
```
|
```
|
||||||
@ -174,6 +176,8 @@ If you have an addition to this list, please submit a pull request.
|
|||||||
- Wurstchen v2.
|
- Wurstchen v2.
|
||||||
- Image to text.
|
- Image to text.
|
||||||
- BLIP.
|
- BLIP.
|
||||||
|
- Text to text.
|
||||||
|
- Marian MT (Machine Translation).
|
||||||
- Computer Vision Models.
|
- Computer Vision Models.
|
||||||
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
|
||||||
- yolo-v3, yolo-v8.
|
- yolo-v3, yolo-v8.
|
||||||
|
19
candle-examples/examples/marian-mt/README.md
Normal file
19
candle-examples/examples/marian-mt/README.md
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# candle-marian-mt
|
||||||
|
|
||||||
|
`marian-mt` is a neural machine translation model. In this example it is used to
|
||||||
|
translate text from French to English. See the associated [model
|
||||||
|
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
|
||||||
|
the model itself.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example marian-mt --release -- \
|
||||||
|
--text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
|
||||||
|
I know you are waiting for me. I will go through the forest, I will go through the
|
||||||
|
mountain. I cannot stay far from you any longer.</s>
|
||||||
|
```
|
@ -8,7 +8,6 @@ use anyhow::Error as E;
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use candle::{DType, Tensor};
|
use candle::{DType, Tensor};
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::models::marian;
|
use candle_transformers::models::marian;
|
||||||
|
|
||||||
@ -18,10 +17,13 @@ use tokenizers::Tokenizer;
|
|||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
struct Args {
|
struct Args {
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: String,
|
model: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer: String,
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_dec: Option<String>,
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
/// Run on CPU rather than on GPU.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -37,25 +39,52 @@ struct Args {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
pub fn main() -> anyhow::Result<()> {
|
||||||
|
use hf_hub::api::sync::Api;
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
let config = marian::Config::opus_mt_tc_big_fr_en();
|
let config = marian::Config::opus_mt_tc_big_fr_en();
|
||||||
|
let tokenizer = {
|
||||||
|
let tokenizer = match args.tokenizer {
|
||||||
|
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||||
|
None => Api::new()?
|
||||||
|
.model("lmz/candle-marian".to_string())
|
||||||
|
.get("tokenizer-marian-fr.json")?,
|
||||||
|
};
|
||||||
|
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let tokenizer_dec = {
|
||||||
|
let tokenizer = match args.tokenizer_dec {
|
||||||
|
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||||
|
None => Api::new()?
|
||||||
|
.model("lmz/candle-marian".to_string())
|
||||||
|
.get("tokenizer-marian-en.json")?,
|
||||||
|
};
|
||||||
|
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||||
|
};
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? };
|
let vb = {
|
||||||
|
let model = match args.model {
|
||||||
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
|
None => Api::new()?
|
||||||
|
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
|
||||||
|
.get("model.safetensors")?,
|
||||||
|
};
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
|
||||||
|
};
|
||||||
let model = marian::MTModel::new(&config, vb)?;
|
let model = marian::MTModel::new(&config, vb)?;
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?;
|
|
||||||
let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone());
|
|
||||||
let mut logits_processor =
|
let mut logits_processor =
|
||||||
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
|
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
|
||||||
|
|
||||||
let encoder_xs = {
|
let encoder_xs = {
|
||||||
let tokens = tokenizer
|
let mut tokens = tokenizer
|
||||||
.encode(args.text, true)
|
.encode(args.text, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
tokens.push(config.eos_token_id);
|
||||||
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||||
model.encoder().forward(&tokens, 0)?
|
model.encoder().forward(&tokens, 0)?
|
||||||
};
|
};
|
||||||
@ -70,20 +99,15 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||||
let token = logits_processor.sample(&logits)?;
|
let token = logits_processor.sample(&logits)?;
|
||||||
|
token_ids.push(token);
|
||||||
println!("{token}");
|
println!("{token}");
|
||||||
if token == config.eos_token_id || token == config.forced_eos_token_id {
|
if token == config.eos_token_id || token == config.forced_eos_token_id {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
token_ids.push(token);
|
|
||||||
if let Some(t) = tokenizer_dec.next_token(token)? {
|
|
||||||
use std::io::Write;
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
|
println!(
|
||||||
print!("{rest}");
|
"{}",
|
||||||
}
|
tokenizer_dec.decode(&token_ids, true).map_err(E::msg)?
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -135,7 +135,12 @@ impl Attention {
|
|||||||
.contiguous()
|
.contiguous()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor, kv_states: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&self,
|
||||||
|
xs: &Tensor,
|
||||||
|
kv_states: Option<&Tensor>,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let is_cross_attn = kv_states.is_some();
|
let is_cross_attn = kv_states.is_some();
|
||||||
let (b_sz, tgt_len, _) = xs.dims3()?;
|
let (b_sz, tgt_len, _) = xs.dims3()?;
|
||||||
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
|
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
|
||||||
@ -156,7 +161,10 @@ impl Attention {
|
|||||||
let key_states = key_states.reshape(proj_shape)?;
|
let key_states = key_states.reshape(proj_shape)?;
|
||||||
let value_states = value_states.reshape(proj_shape)?;
|
let value_states = value_states.reshape(proj_shape)?;
|
||||||
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
|
||||||
// todo: attn_mask
|
let attn_weights = match attn_mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
|
||||||
|
};
|
||||||
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
let attn_output = attn_probs.matmul(&value_states)?;
|
let attn_output = attn_probs.matmul(&value_states)?;
|
||||||
attn_output
|
attn_output
|
||||||
@ -196,8 +204,8 @@ impl EncoderLayer {
|
|||||||
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
let residual = xs;
|
let residual = xs;
|
||||||
let xs =
|
let xs = (self.self_attn.forward(xs, None, None)? + residual)?
|
||||||
(self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
|
.apply(&self.self_attn_layer_norm)?;
|
||||||
let residual = &xs;
|
let residual = &xs;
|
||||||
let xs = xs
|
let xs = xs
|
||||||
.apply(&self.fc1)?
|
.apply(&self.fc1)?
|
||||||
@ -241,15 +249,20 @@ impl DecoderLayer {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor, encoder_xs: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(
|
||||||
|
&self,
|
||||||
|
xs: &Tensor,
|
||||||
|
encoder_xs: Option<&Tensor>,
|
||||||
|
attn_mask: &Tensor,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let residual = xs;
|
let residual = xs;
|
||||||
let xs =
|
let xs = (self.self_attn.forward(xs, None, Some(attn_mask))? + residual)?
|
||||||
(self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
|
.apply(&self.self_attn_layer_norm)?;
|
||||||
let xs = match encoder_xs {
|
let xs = match encoder_xs {
|
||||||
None => xs,
|
None => xs,
|
||||||
Some(encoder_xs) => {
|
Some(encoder_xs) => {
|
||||||
let residual = &xs;
|
let residual = &xs;
|
||||||
let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
|
let xs = self.encoder_attn.forward(&xs, Some(encoder_xs), None)?;
|
||||||
(residual + xs)?.apply(&self.encoder_attn_layer_norm)?
|
(residual + xs)?.apply(&self.encoder_attn_layer_norm)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -346,6 +359,7 @@ impl Decoder {
|
|||||||
xs: &Tensor,
|
xs: &Tensor,
|
||||||
encoder_xs: Option<&Tensor>,
|
encoder_xs: Option<&Tensor>,
|
||||||
past_kv_len: usize,
|
past_kv_len: usize,
|
||||||
|
attn_mask: &Tensor,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let xs = xs.apply(&self.embed_tokens)?;
|
let xs = xs.apply(&self.embed_tokens)?;
|
||||||
let xs = match self.embed_scale {
|
let xs = match self.embed_scale {
|
||||||
@ -358,7 +372,7 @@ impl Decoder {
|
|||||||
.unsqueeze(0)?;
|
.unsqueeze(0)?;
|
||||||
let mut xs = xs.broadcast_add(&embed_pos)?;
|
let mut xs = xs.broadcast_add(&embed_pos)?;
|
||||||
for layer in self.layers.iter() {
|
for layer in self.layers.iter() {
|
||||||
xs = layer.forward(&xs, encoder_xs)?;
|
xs = layer.forward(&xs, encoder_xs, attn_mask)?;
|
||||||
}
|
}
|
||||||
Ok(xs)
|
Ok(xs)
|
||||||
}
|
}
|
||||||
@ -413,9 +427,14 @@ impl MTModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
|
pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let seq_len = xs.dim(1)?;
|
||||||
|
let mask: Vec<_> = (0..seq_len)
|
||||||
|
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
|
||||||
self.model
|
self.model
|
||||||
.decoder
|
.decoder
|
||||||
.forward(xs, Some(encoder_xs), 0)?
|
.forward(xs, Some(encoder_xs), 0, &mask)?
|
||||||
.apply(&self.lm_head)?
|
.apply(&self.lm_head)?
|
||||||
.broadcast_add(&self.final_logits_bias)
|
.broadcast_add(&self.final_logits_bias)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user