From 392a00a147c26ebe70c6484d72223d02ada6a72a Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 30 Oct 2023 20:20:36 +0100
Subject: [PATCH] Add support for the marian base model. (#1221)

---
 candle-examples/examples/marian-mt/main.rs | 56 +++++++++++++++++-----
 candle-nn/src/activation.rs                |  2 +
 candle-transformers/src/models/marian.rs   | 25 ++++++++++
 3 files changed, 72 insertions(+), 11 deletions(-)

diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs
index c198777c..c503667c 100644
--- a/candle-examples/examples/marian-mt/main.rs
+++ b/candle-examples/examples/marian-mt/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use anyhow::Error as E;
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
@@ -13,6 +13,12 @@ use candle_nn::VarBuilder;
 use candle_transformers::models::marian;
 use tokenizers::Tokenizer;
 
+#[derive(Clone, Debug, Copy, ValueEnum)]
+enum Which {
+    Base,
+    Big,
+}
+
 // TODO: Maybe add support for the conditional prompt.
 #[derive(Parser)]
 struct Args {
@@ -25,6 +31,10 @@ struct Args {
     #[arg(long)]
     tokenizer_dec: Option<String>,
 
+    /// Choose the variant of the model to run.
+    #[arg(long, default_value = "big")]
+    which: Which,
+
     /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
@@ -42,13 +52,22 @@ pub fn main() -> anyhow::Result<()> {
     use hf_hub::api::sync::Api;
     let args = Args::parse();
 
-    let config = marian::Config::opus_mt_tc_big_fr_en();
+    let config = match args.which {
+        Which::Base => marian::Config::opus_mt_fr_en(),
+        Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
+    };
     let tokenizer = {
         let tokenizer = match args.tokenizer {
             Some(tokenizer) => std::path::PathBuf::from(tokenizer),
-            None => Api::new()?
-                .model("lmz/candle-marian".to_string())
-                .get("tokenizer-marian-fr.json")?,
+            None => {
+                let name = match args.which {
+                    Which::Base => "tokenizer-marian-base-fr.json",
+                    Which::Big => "tokenizer-marian-fr.json",
+                };
+                Api::new()?
+                    .model("lmz/candle-marian".to_string())
+                    .get(name)?
+            }
         };
         Tokenizer::from_file(&tokenizer).map_err(E::msg)?
     };
@@ -56,9 +75,15 @@ pub fn main() -> anyhow::Result<()> {
     let tokenizer_dec = {
         let tokenizer = match args.tokenizer_dec {
             Some(tokenizer) => std::path::PathBuf::from(tokenizer),
-            None => Api::new()?
-                .model("lmz/candle-marian".to_string())
-                .get("tokenizer-marian-en.json")?,
+            None => {
+                let name = match args.which {
+                    Which::Base => "tokenizer-marian-base-en.json",
+                    Which::Big => "tokenizer-marian-en.json",
+                };
+                Api::new()?
+                    .model("lmz/candle-marian".to_string())
+                    .get(name)?
+            }
         };
         Tokenizer::from_file(&tokenizer).map_err(E::msg)?
     };
@@ -67,9 +92,18 @@ pub fn main() -> anyhow::Result<()> {
     let vb = {
         let model = match args.model {
             Some(model) => std::path::PathBuf::from(model),
-            None => Api::new()?
-                .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
-                .get("model.safetensors")?,
+            None => match args.which {
+                Which::Base => Api::new()?
+                    .repo(hf_hub::Repo::with_revision(
+                        "Helsinki-NLP/opus-mt-fr-en".to_string(),
+                        hf_hub::RepoType::Model,
+                        "refs/pr/4".to_string(),
+                    ))
+                    .get("model.safetensors")?,
+                Which::Big => Api::new()?
+                    .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
+                    .get("model.safetensors")?,
+            },
         };
         unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
     };
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 52ceba78..79cf9c82 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -13,6 +13,7 @@ pub enum Activation {
     Relu6,
     Silu,
     Sigmoid,
+    Swish,
     Elu(f64),
     LeakyRelu(f64),
 }
@@ -28,6 +29,7 @@ impl super::Module for Activation {
             Self::Relu6 => xs.clamp(0f32, 6f32),
             Self::Silu => crate::ops::silu(xs),
             Self::Sigmoid => crate::ops::sigmoid(xs),
+            Self::Swish => xs * crate::ops::sigmoid(xs)?,
             &Self::Elu(alpha) => xs.elu(alpha),
             &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
         }
diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs
index 2bcfd2f7..5305d4d8 100644
--- a/candle-transformers/src/models/marian.rs
+++ b/candle-transformers/src/models/marian.rs
@@ -51,6 +51,31 @@ impl Config {
             vocab_size: 53017,
         }
     }
+
+    // https://huggingface.co/Helsinki-NLP/opus-mt-fr-en/blob/main/config.json
+    pub fn opus_mt_fr_en() -> Self {
+        Self {
+            activation_function: candle_nn::Activation::Swish,
+            d_model: 512,
+            decoder_attention_heads: 8,
+            decoder_ffn_dim: 2048,
+            decoder_layers: 6,
+            decoder_start_token_id: 59513,
+            decoder_vocab_size: Some(59514),
+            encoder_attention_heads: 8,
+            encoder_ffn_dim: 2048,
+            encoder_layers: 6,
+            eos_token_id: 0,
+            forced_eos_token_id: 0,
+            is_encoder_decoder: true,
+            max_position_embeddings: 512,
+            pad_token_id: 59513,
+            scale_embedding: true,
+            share_encoder_decoder_embeddings: true,
+            use_cache: true,
+            vocab_size: 59514,
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
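
Notes (not part of the patch):

On the activation change: Swish computes x * sigmoid(x), which is the same
function as the existing Silu variant; the separate enum case keeps the Rust
config aligned with the "swish" spelling used in the upstream opus-mt-fr-en
config.json linked above. A minimal sketch checking the equivalence, using
the `candle`/`candle_nn` crate names as they appear in this repo:

    use candle::{Device, Tensor};
    use candle_nn::{Activation, Module};

    fn main() -> candle::Result<()> {
        let xs = Tensor::new(&[-2f32, -1.0, 0.0, 1.0, 2.0], &Device::Cpu)?;
        // The new variant: x * sigmoid(x), applied elementwise.
        let swish = Activation::Swish.forward(&xs)?;
        // SiLU is the same function under another name, so the two
        // outputs should match.
        let silu = Activation::Silu.forward(&xs)?;
        println!("swish: {swish}");
        println!("silu:  {silu}");
        Ok(())
    }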
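
On the weight loading: for the base variant the example fetches
model.safetensors from the Hub revision refs/pr/4 of Helsinki-NLP/opus-mt-fr-en
rather than from main, presumably because the safetensors conversion of that
checkpoint lives on a Hub PR branch. A sketch of a helper bundling the loading
logic above into one function; `marian::MTModel` and its constructor are taken
from the rest of the example file and are assumptions here, not something this
patch introduces:

    use candle::{DType, Device};
    use candle_nn::VarBuilder;
    use candle_transformers::models::marian;
    use hf_hub::api::sync::Api;

    fn load_base_model(device: &Device) -> anyhow::Result<marian::MTModel> {
        // Hand-written config mirroring the upstream config.json.
        let config = marian::Config::opus_mt_fr_en();
        // Fetch the safetensors weights from the PR revision of the repo.
        let weights = Api::new()?
            .repo(hf_hub::Repo::with_revision(
                "Helsinki-NLP/opus-mt-fr-en".to_string(),
                hf_hub::RepoType::Model,
                "refs/pr/4".to_string(),
            ))
            .get("model.safetensors")?;
        // Memory-map the weights; unsafe because the file must not be
        // mutated while it is mapped.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[&weights], DType::F32, device)?
        };
        Ok(marian::MTModel::new(&config, vb)?)
    }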
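
Usage: with the patch applied, the base model is selected via the new flag,
e.g. `cargo run --example marian-mt --release -- --which base` (clap's
ValueEnum derives the lowercase value names from the Which variants), plus
the example's other arguments, which this patch leaves unchanged. Since
`--which` defaults to "big", existing invocations keep running the
tc-big checkpoint.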