Files
candle/candle-examples/examples/distilbert/README.md
Kyle Birnbaum eb478ece92 Implementing DistilBertForMaskedLM. (#2866)
* Initial commit: model weights working, prediciton incorrect

* moved distilbertformaskedlm into distilbert modeling file

* made maskedLM like bert example, still incorrect predictions

* finally not getting NaNs, fixed attention mask

* getting correct output sentences

* get top k predictions

* fixed output formatting slightly

* added default arg for model_id

* lint

* moved masked token example code from distilbertformaskedlm example to distilbert example

* lint

* removed distilbertformaskedlm example

* cleanup

* clippy

* removed embedding normalization from example

* made output and model dependent on args instead of prompt

* lint

* replaced or_ok anyhow error with anyhow context

* changed error message for mask token not found
2025-04-11 13:25:39 +02:00

1.4 KiB

candle-distilbert

DistilBert is a distiled version of the Bert model.

Sentence embeddings

DistilBert is used to compute the sentence embeddings for a prompt. The model weights are downloaded from the hub on the first run.

$ cargo run --example distilbert --release -- --prompt "Here is a test sentence"

> [[[ 0.5109,  0.1280, -0.2635, ...,  0.3462, -1.0434,  0.1441],
>   [ 0.1735,  0.0818, -0.5549, ...,  0.3472, -0.8264, -0.0244],
>   [ 0.0702, -0.1311, -0.4914, ...,  0.3483, -0.6194,  0.1829],
>   ...
>   [ 0.2993, -0.0106, -0.4640, ...,  0.2844, -0.6732,  0.0042],
>   [ 0.1066, -0.0081, -0.4299, ...,  0.3435, -0.7729,  0.0190],
>   [ 0.8903,  0.2055, -0.2541, ...,  0.3208, -0.6585,  0.0586]]]
> Tensor[[1, 7, 768], f32]

Masked Token

DistilBert is used to compute the top K choices for a masked token.

$ cargo run --example distilbert -- --prompt "The capital of France is [MASK]." --top-k 10

> Input: The capital of France is [MASK].
> Predictions for [MASK] at position 6:
>   1: marseille       (probability: 12.14%)
>   2: paris           (probability: 10.84%)
>   3: toulouse        (probability: 8.57%)
>   4: lyon            (probability: 7.61%)
>   5: montpellier     (probability: 5.18%)
>   6: bordeaux        (probability: 4.88%)
>   7: nantes          (probability: 4.82%)
>   8: lille           (probability: 4.07%)
>   9: strasbourg      (probability: 3.12%)
>   10: cannes          (probability: 3.04%)