mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add support for the marian base model. (#1221)
This commit is contained in:
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Error as E;
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
@ -13,6 +13,12 @@ use candle_transformers::models::marian;
|
||||
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
/// Model variant selectable on the command line through the `--which` argument.
/// `main` matches on this value to choose the config, the default tokenizer
/// file names and the default weights repository.
#[derive(Clone, Debug, Copy, ValueEnum)]
|
||||
enum Which {
|
||||
// Selects `marian::Config::opus_mt_fr_en()` and the `tokenizer-marian-base-*` files.
Base,
|
||||
// Selects `marian::Config::opus_mt_tc_big_fr_en()`; this is the CLI default ("big").
Big,
|
||||
}
|
||||
|
||||
// TODO: Maybe add support for the conditional prompt.
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
@ -25,6 +31,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tokenizer_dec: Option<String>,
|
||||
|
||||
/// Choose the variant of the model to run.
|
||||
#[arg(long, default_value = "big")]
|
||||
which: Which,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
@ -42,13 +52,22 @@ pub fn main() -> anyhow::Result<()> {
|
||||
use hf_hub::api::sync::Api;
|
||||
let args = Args::parse();
|
||||
|
||||
let config = marian::Config::opus_mt_tc_big_fr_en();
|
||||
let config = match args.which {
|
||||
Which::Base => marian::Config::opus_mt_fr_en(),
|
||||
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
|
||||
};
|
||||
let tokenizer = {
|
||||
let tokenizer = match args.tokenizer {
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
None => Api::new()?
|
||||
None => {
|
||||
let name = match args.which {
|
||||
Which::Base => "tokenizer-marian-base-fr.json",
|
||||
Which::Big => "tokenizer-marian-fr.json",
|
||||
};
|
||||
Api::new()?
|
||||
.model("lmz/candle-marian".to_string())
|
||||
.get("tokenizer-marian-fr.json")?,
|
||||
.get(name)?
|
||||
}
|
||||
};
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
};
|
||||
@ -56,9 +75,15 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let tokenizer_dec = {
|
||||
let tokenizer = match args.tokenizer_dec {
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
None => Api::new()?
|
||||
None => {
|
||||
let name = match args.which {
|
||||
Which::Base => "tokenizer-marian-base-en.json",
|
||||
Which::Big => "tokenizer-marian-en.json",
|
||||
};
|
||||
Api::new()?
|
||||
.model("lmz/candle-marian".to_string())
|
||||
.get("tokenizer-marian-en.json")?,
|
||||
.get(name)?
|
||||
}
|
||||
};
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
};
|
||||
@ -67,9 +92,18 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let vb = {
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => Api::new()?
|
||||
None => match args.which {
|
||||
Which::Base => Api::new()?
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
"Helsinki-NLP/opus-mt-fr-en".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/4".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
Which::Big => Api::new()?
|
||||
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
|
||||
.get("model.safetensors")?,
|
||||
},
|
||||
};
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
|
||||
};
|
||||
|
@ -13,6 +13,7 @@ pub enum Activation {
|
||||
Relu6,
|
||||
Silu,
|
||||
Sigmoid,
|
||||
Swish,
|
||||
Elu(f64),
|
||||
LeakyRelu(f64),
|
||||
}
|
||||
@ -28,6 +29,7 @@ impl super::Module for Activation {
|
||||
Self::Relu6 => xs.clamp(0f32, 6f32),
|
||||
Self::Silu => crate::ops::silu(xs),
|
||||
Self::Sigmoid => crate::ops::sigmoid(xs),
|
||||
Self::Swish => xs * crate::ops::sigmoid(xs)?,
|
||||
&Self::Elu(alpha) => xs.elu(alpha),
|
||||
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
|
||||
}
|
||||
|
@ -51,6 +51,31 @@ impl Config {
|
||||
vocab_size: 53017,
|
||||
}
|
||||
}
|
||||
|
||||
// https://huggingface.co/Helsinki-NLP/opus-mt-fr-en/blob/main/config.json
|
||||
/// Returns the hard-coded `Config` for the base `opus-mt-fr-en` Marian model.
///
/// NOTE(review): the values below presumably mirror the `config.json` linked
/// above — confirm against upstream if that file changes.
pub fn opus_mt_fr_en() -> Self {
|
||||
Self {
|
||||
activation_function: candle_nn::Activation::Swish,
|
||||
d_model: 512,
|
||||
decoder_attention_heads: 8,
|
||||
decoder_ffn_dim: 2048,
|
||||
decoder_layers: 6,
|
||||
// Same id as `pad_token_id` below (59513 == vocab_size - 1).
decoder_start_token_id: 59513,
|
||||
decoder_vocab_size: Some(59514),
|
||||
encoder_attention_heads: 8,
|
||||
encoder_ffn_dim: 2048,
|
||||
encoder_layers: 6,
|
||||
eos_token_id: 0,
|
||||
forced_eos_token_id: 0,
|
||||
is_encoder_decoder: true,
|
||||
max_position_embeddings: 512,
|
||||
pad_token_id: 59513,
|
||||
scale_embedding: true,
|
||||
share_encoder_decoder_embeddings: true,
|
||||
use_cache: true,
|
||||
// Equal to the value wrapped in `decoder_vocab_size` above.
vocab_size: 59514,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
Reference in New Issue
Block a user