mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Added new language pairs to marian-mt example. (#2860)
* added new language pairs to marian-mt * lint * seperated python code for converting tokenizers into its own file and and added a reqirements.txt for dependencies, updated instructions in readme and included python version * Cleanup. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -20,6 +20,22 @@ enum Which {
|
||||
Big,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum LanguagePair {
|
||||
#[value(name = "fr-en")]
|
||||
FrEn,
|
||||
#[value(name = "en-zh")]
|
||||
EnZh,
|
||||
#[value(name = "en-hi")]
|
||||
EnHi,
|
||||
#[value(name = "en-es")]
|
||||
EnEs,
|
||||
#[value(name = "en-fr")]
|
||||
EnFr,
|
||||
#[value(name = "en-ru")]
|
||||
EnRu,
|
||||
}
|
||||
|
||||
// TODO: Maybe add support for the conditional prompt.
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
@ -36,6 +52,10 @@ struct Args {
|
||||
#[arg(long, default_value = "big")]
|
||||
which: Which,
|
||||
|
||||
// Choose which language pair to use
|
||||
#[arg(long, default_value = "fr-en")]
|
||||
language_pair: LanguagePair,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
@ -53,21 +73,43 @@ pub fn main() -> anyhow::Result<()> {
|
||||
use hf_hub::api::sync::Api;
|
||||
let args = Args::parse();
|
||||
|
||||
let config = match args.which {
|
||||
Which::Base => marian::Config::opus_mt_fr_en(),
|
||||
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
|
||||
let config = match (args.which, args.language_pair) {
|
||||
(Which::Base, LanguagePair::FrEn) => marian::Config::opus_mt_fr_en(),
|
||||
(Which::Big, LanguagePair::FrEn) => marian::Config::opus_mt_tc_big_fr_en(),
|
||||
(Which::Base, LanguagePair::EnZh) => marian::Config::opus_mt_en_zh(),
|
||||
(Which::Base, LanguagePair::EnHi) => marian::Config::opus_mt_en_hi(),
|
||||
(Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(),
|
||||
(Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_fr_en(),
|
||||
(Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(),
|
||||
(Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"),
|
||||
};
|
||||
let tokenizer_default_repo = match args.language_pair {
|
||||
LanguagePair::FrEn => "lmz/candle-marian",
|
||||
LanguagePair::EnZh
|
||||
| LanguagePair::EnHi
|
||||
| LanguagePair::EnEs
|
||||
| LanguagePair::EnFr
|
||||
| LanguagePair::EnRu => "KeighBee/candle-marian",
|
||||
};
|
||||
let tokenizer = {
|
||||
let tokenizer = match args.tokenizer {
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
None => {
|
||||
let name = match args.which {
|
||||
Which::Base => "tokenizer-marian-base-fr.json",
|
||||
Which::Big => "tokenizer-marian-fr.json",
|
||||
let filename = match (args.which, args.language_pair) {
|
||||
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-fr.json",
|
||||
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-fr.json",
|
||||
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-en.json",
|
||||
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-en.json",
|
||||
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json",
|
||||
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json",
|
||||
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json",
|
||||
(Which::Big, lp) => {
|
||||
anyhow::bail!("big is not supported for language pair {lp:?}")
|
||||
}
|
||||
};
|
||||
Api::new()?
|
||||
.model("lmz/candle-marian".to_string())
|
||||
.get(name)?
|
||||
.model(tokenizer_default_repo.to_string())
|
||||
.get(filename)?
|
||||
}
|
||||
};
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
@ -77,13 +119,21 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let tokenizer = match args.tokenizer_dec {
|
||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||
None => {
|
||||
let name = match args.which {
|
||||
Which::Base => "tokenizer-marian-base-en.json",
|
||||
Which::Big => "tokenizer-marian-en.json",
|
||||
let filename = match (args.which, args.language_pair) {
|
||||
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-en.json",
|
||||
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-en.json",
|
||||
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-zh.json",
|
||||
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-hi.json",
|
||||
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json",
|
||||
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json",
|
||||
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json",
|
||||
(Which::Big, lp) => {
|
||||
anyhow::bail!("big is not supported for language pair {lp:?}")
|
||||
}
|
||||
};
|
||||
Api::new()?
|
||||
.model("lmz/candle-marian".to_string())
|
||||
.get(name)?
|
||||
.model(tokenizer_default_repo.to_string())
|
||||
.get(filename)?
|
||||
}
|
||||
};
|
||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||
@ -94,18 +144,48 @@ pub fn main() -> anyhow::Result<()> {
|
||||
let vb = {
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => match args.which {
|
||||
Which::Base => Api::new()?
|
||||
.repo(hf_hub::Repo::with_revision(
|
||||
None => {
|
||||
let api = Api::new()?;
|
||||
let api = match (args.which, args.language_pair) {
|
||||
(Which::Base, LanguagePair::FrEn) => api.repo(hf_hub::Repo::with_revision(
|
||||
"Helsinki-NLP/opus-mt-fr-en".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/4".to_string(),
|
||||
))
|
||||
.get("model.safetensors")?,
|
||||
Which::Big => Api::new()?
|
||||
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
|
||||
.get("model.safetensors")?,
|
||||
},
|
||||
)),
|
||||
(Which::Big, LanguagePair::FrEn) => {
|
||||
api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
|
||||
}
|
||||
(Which::Base, LanguagePair::EnZh) => api.repo(hf_hub::Repo::with_revision(
|
||||
"Helsinki-NLP/opus-mt-en-zh".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/13".to_string(),
|
||||
)),
|
||||
(Which::Base, LanguagePair::EnHi) => api.repo(hf_hub::Repo::with_revision(
|
||||
"Helsinki-NLP/opus-mt-en-hi".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/3".to_string(),
|
||||
)),
|
||||
(Which::Base, LanguagePair::EnEs) => api.repo(hf_hub::Repo::with_revision(
|
||||
"Helsinki-NLP/opus-mt-en-es".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/4".to_string(),
|
||||
)),
|
||||
(Which::Base, LanguagePair::EnFr) => api.repo(hf_hub::Repo::with_revision(
|
||||
"Helsinki-NLP/opus-mt-en-fr".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/9".to_string(),
|
||||
)),
|
||||
(Which::Base, LanguagePair::EnRu) => api.repo(hf_hub::Repo::with_revision(
|
||||
"Helsinki-NLP/opus-mt-en-ru".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/7".to_string(),
|
||||
)),
|
||||
(Which::Big, lp) => {
|
||||
anyhow::bail!("big is not supported for language pair {lp:?}")
|
||||
}
|
||||
};
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
};
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
|
||||
};
|
||||
|
Reference in New Issue
Block a user