mirror of https://github.com/huggingface/candle.git
Adds support for stella_en_v5 embedding model - 400M variant (#2608)
* Adds support for stella_en_v5 embedding model - 400M variant

* Unified stella

* WIP: Unified Stella

* Combined stella for both 1.5B and 400M variants

* Cargo fmt for the CI

* removed redundant stella-400m model and example after merge into stella-en-v5

* cargo fmt --all

---------

Co-authored-by: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
@@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling
The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.

```bash
-$ cargo run --example stella-en-v5 --release --features <metal | cuda>
+$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 1.5b

>
> Score: 0.8178786
@@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features <metal | cuda>
> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
>

+$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 400m
+
+>
+> Score: 0.8397539
+> Query: What are some ways to reduce stress?
+> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
+> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
+> stress from building up.
+>
+>
+>
+> Score: 0.809545
+> Query: What are the benefits of drinking green tea?
+> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
+> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
+> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
+>
```

## Supported options:
-- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with the `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
+- `Stella_en_v5` has 2 model variants published - a 1.5B variant and a 400M variant. This is selected through the flag `--which`. E.g. `--which 400m` or `--which 1.5b`.
+
+- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with the `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.

- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled through the `--task` option; a sketch of that preprocessing follows below.
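As a minimal sketch of that task-specific preprocessing (not the commit's code: the `EncodeTask` name mirrors the enum visible in the diff below, and the template strings are the s2s/s2p prompts as given on the model card, so verify them against it):

```rust
/// Sketch of s2s/s2p query preprocessing. Illustrative only; the real
/// example lives in candle-examples, and names here are assumptions.
#[derive(Clone, Copy, Debug)]
enum EncodeTask {
    /// s2p: sentence-to-passage (retrieval).
    S2P,
    /// s2s: sentence-to-sentence (similarity).
    S2S,
}

impl EncodeTask {
    fn preprocess_query(&self, query: &str) -> String {
        let instruct = match self {
            Self::S2P => {
                "Given a web search query, retrieve relevant passages that answer the query."
            }
            Self::S2S => "Retrieve semantically similar text.",
        };
        // Per the model card, only queries get a prompt template;
        // documents are embedded as-is.
        format!("Instruct: {instruct}\nQuery: {query}")
    }
}

fn main() {
    let q = EncodeTask::S2P.preprocess_query("What are some ways to reduce stress?");
    println!("{q}");
}
```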
@@ -212,6 +212,14 @@ impl EncodeTask {
    }
}

+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "1.5b")]
+    Large,
+    #[value(name = "400m")]
+    Small,
+}
+
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
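An aside on the new `Which` enum above (not part of the diff): clap's `ValueEnum` derive takes the `#[value(name = ...)]` attribute as the CLI spelling, so `--which 400m` parses to `Which::Small`. A self-contained sketch using `ValueEnum::from_str`, the parsing entry point clap uses under the hood:

```rust
use clap::ValueEnum;

#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "1.5b")]
    Large,
    #[value(name = "400m")]
    Small,
}

fn main() {
    // `--which 400m` on the command line maps to `Which::Small`,
    // `--which 1.5b` to `Which::Large`.
    assert_eq!(Which::from_str("400m", true), Ok(Which::Small));
    assert_eq!(Which::from_str("1.5b", true), Ok(Which::Large));
}
```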
@@ -219,6 +227,9 @@ struct Args {
    #[arg(long)]
    cpu: bool,

+    #[arg(long)]
+    which: Which,
+
    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,
@@ -250,24 +261,33 @@ struct Args {

// Tokenizer creation is super critical in our case.
// We are going to be `padding: Left` for each batch
-fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
+fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result<Tokenizer> {
    let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
-    let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
-        pad_id
-    } else {
-        return Err(anyhow!(
-            "Tokenizer doesn't contain expected `<|endoftext|>` token"
-        ));
-    };
-
-    // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
-    tokenizer.with_padding(Some(PaddingParams {
-        strategy: PaddingStrategy::BatchLongest,
-        direction: PaddingDirection::Left,
-        pad_id,
-        pad_token: "<|endoftext|>".to_string(),
-        ..Default::default()
-    }));
+    if which == Which::Large {
+        let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
+            pad_id
+        } else {
+            return Err(anyhow!(
+                "Tokenizer doesn't contain expected `<|endoftext|>` token"
+            ));
+        };
+
+        // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
+        tokenizer.with_padding(Some(PaddingParams {
+            strategy: PaddingStrategy::BatchLongest,
+            direction: PaddingDirection::Left,
+            pad_id,
+            pad_token: "<|endoftext|>".to_string(),
+            ..Default::default()
+        }));
+    } else {
+        tokenizer.with_padding(Some(PaddingParams {
+            strategy: PaddingStrategy::BatchLongest,
+            direction: PaddingDirection::Right,
+            ..Default::default()
+        }));
+    }

    Ok(tokenizer)
}
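An aside on the padding split above (not part of the diff): left padding keeps every sequence's last real token in the final position, which is what models that pool the last hidden state need, while right padding is the usual default. A minimal sketch with the `tokenizers` crate, using its default pad token/id for brevity (the `tokenizer.json` path is illustrative):

```rust
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Any tokenizer.json works for this demonstration.
    let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // With Left padding, pad ids land at the FRONT of shorter rows,
    // so each sequence still ends on its own final token.
    tokenizer.with_padding(Some(PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        direction: PaddingDirection::Left,
        ..Default::default()
    }));

    let batch = tokenizer.encode_batch(vec!["short", "a much longer input"], true)?;
    for enc in &batch {
        println!("{:?}", enc.get_ids());
    }
    Ok(())
}
```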
@@ -298,7 +318,19 @@ fn main() -> Result<()> {
        Some(d) => d,
        None => EmbedDim::Dim1024,
    };
-    let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));
+
+    let (repo, cfg) = match args.which {
+        Which::Large => (
+            "dunzhang/stella_en_1.5B_v5",
+            Config::new_1_5_b_v5(embed_dim.embed_dim()),
+        ),
+        Which::Small => (
+            "dunzhang/stella_en_400M_v5",
+            Config::new_400_m_v5(embed_dim.embed_dim()),
+        ),
+    };
+
+    let repo = api.repo(Repo::model(repo.to_string()));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
@@ -330,7 +362,7 @@ fn main() -> Result<()> {
    println!("retrieved the files in {:?}", start.elapsed());

    // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
-    let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;
+    let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?;

    let start = std::time::Instant::now();
@@ -343,11 +375,7 @@ fn main() -> Result<()> {
    let embed_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };

-    let model = EmbeddingModel::new(
-        &Config::new_1_5_b_v5(embed_dim.embed_dim()),
-        base_vb,
-        embed_vb,
-    )?;
+    let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?;

    println!("loaded the model in {:?}", start.elapsed());
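For context on the `Score:` lines in the README output above: they are query/document similarity scores over the returned embeddings. A hedged sketch of cosine similarity in plain Rust (the example itself works with candle tensors, but the math is the same; with unit-normalized embeddings it reduces to a dot product):

```rust
/// Cosine similarity between two embedding vectors. Illustrative only.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}

fn main() {
    // Toy vectors standing in for a query embedding and a doc embedding.
    let query = [0.1_f32, 0.9, 0.4];
    let doc = [0.2_f32, 0.8, 0.5];
    println!("Score: {}", cosine_similarity(&query, &doc));
}
```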