T5 tweaks (#831)

* Use default values rather than options.

* Avoid exposing the device field.

* More tweaks.
This commit is contained in:
Laurent Mazare
2023-09-13 08:37:04 +02:00
committed by GitHub
parent d801e1d564
commit e4553fb355
2 changed files with 35 additions and 32 deletions

View File

@@ -30,7 +30,7 @@ struct Args {
#[arg(long)]
tracing: bool,
-/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
+/// The model repository to use on the HuggingFace hub.
#[arg(long)]
model_id: Option<String>,
@@ -94,22 +94,10 @@ impl Args {
}
fn main() -> Result<()> {
-use tracing_chrome::ChromeLayerBuilder;
-use tracing_subscriber::prelude::*;
let args = Args::parse();
-let _guard = if args.tracing {
-println!("tracing...");
-let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-tracing_subscriber::registry().with(chrome_layer).init();
-Some(guard)
-} else {
-None
-};
let start = std::time::Instant::now();
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
-let device = &model.device;
let prompt = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
let tokenizer = tokenizer
.with_padding(None)
@@ -120,7 +108,7 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
-let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+let token_ids = Tensor::new(&tokens[..], model.device())?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();