mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
396 lines
16 KiB
HTML
396 lines
16 KiB
HTML
<!DOCTYPE HTML>
|
|
<html lang="en" class="light sidebar-visible" dir="ltr">
|
|
<head>
|
|
<!-- Book generated using mdBook -->
|
|
<meta charset="UTF-8">
|
|
<title>Simplified - Candle Documentation</title>
|
|
|
|
|
|
<!-- Custom HTML head -->
|
|
|
|
<meta name="description" content="">
|
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
|
<meta name="theme-color" content="#ffffff">
|
|
|
|
<link rel="icon" href="../favicon.svg">
|
|
<link rel="shortcut icon" href="../favicon.png">
|
|
<link rel="stylesheet" href="../css/variables.css">
|
|
<link rel="stylesheet" href="../css/general.css">
|
|
<link rel="stylesheet" href="../css/chrome.css">
|
|
<link rel="stylesheet" href="../css/print.css" media="print">
|
|
|
|
<!-- Fonts -->
|
|
<link rel="stylesheet" href="../FontAwesome/css/font-awesome.css">
|
|
<link rel="stylesheet" href="../fonts/fonts.css">
|
|
|
|
<!-- Highlight.js Stylesheets -->
|
|
<link rel="stylesheet" id="highlight-css" href="../highlight.css">
|
|
<link rel="stylesheet" id="tomorrow-night-css" href="../tomorrow-night.css">
|
|
<link rel="stylesheet" id="ayu-highlight-css" href="../ayu-highlight.css">
|
|
|
|
<!-- Custom theme stylesheets -->
|
|
|
|
|
|
<!-- Provide site root and default themes to javascript -->
|
|
<script>
|
|
// Relative path from this page to the book root; the stylesheet and script URLs on this page use the same "../" prefix.
const path_to_root = "../";
|
|
// Theme used when nothing is stored and the browser does not match `prefers-color-scheme: dark`.
const default_light_theme = "light";
|
|
// Theme used when nothing is stored and the browser matches `prefers-color-scheme: dark`.
const default_dark_theme = "navy";
|
|
</script>
|
|
<!-- Start loading toc.js asap -->
|
|
<script src="../toc.js"></script>
|
|
</head>
|
|
<body>
|
|
<div id="body-container">
|
|
<!-- Work around some values being stored in localStorage wrapped in quotes -->
|
|
<script>
    // Some values were stored in localStorage wrapped in double quotes.
    // Rewrite each affected key without the surrounding quotes.
    try {
        for (const key of ['mdbook-theme', 'mdbook-sidebar']) {
            const stored = localStorage.getItem(key);
            // getItem() returns null for a key that was never set. Skip it, so a
            // missing theme no longer throws and blocks the sidebar repair.
            if (stored === null) {
                continue;
            }
            // Require at least two characters so a lone `"` is not turned into an
            // empty string, which would be an invalid class name later on.
            if (stored.length >= 2 && stored.startsWith('"') && stored.endsWith('"')) {
                localStorage.setItem(key, stored.slice(1, -1));
            }
        }
    } catch (e) {
        // Deliberately best-effort: presumably localStorage access can throw when
        // storage is blocked; the page must still render in that case.
    }
</script>
|
|
|
|
<!-- Set the theme before any content is loaded, prevents flash -->
|
|
<script>
    // Pick the fallback theme from the OS-level color-scheme preference.
    const default_theme = window.matchMedia("(prefers-color-scheme: dark)").matches
        ? default_dark_theme
        : default_light_theme;

    // A stored theme wins over the fallback; reading storage is best-effort.
    let theme;
    try {
        theme = localStorage.getItem('mdbook-theme');
    } catch (e) { }
    if (theme == null) {
        // Covers both null (nothing stored) and undefined (the read threw).
        theme = default_theme;
    }

    // Swap the static `light` class on <html> for the resolved theme and flag
    // that JavaScript is running. This runs before content is parsed to avoid a flash.
    const html = document.documentElement;
    html.classList.remove('light');
    html.classList.add(theme);
    html.classList.add("js");
</script>
|
|
|
|
<input type="checkbox" id="sidebar-toggle-anchor" class="hidden">
|
|
|
|
<!-- Hide / unhide sidebar before it is displayed -->
|
|
<script>
    // Decide the initial sidebar state before the sidebar is displayed.
    let sidebar = null;
    const sidebar_toggle = document.getElementById("sidebar-toggle-anchor");
    if (document.body.clientWidth < 1080) {
        // Narrow viewports always start with the sidebar hidden.
        sidebar = 'hidden';
    } else {
        // Wide viewports restore the stored state and default to visible.
        try {
            sidebar = localStorage.getItem('mdbook-sidebar');
        } catch (e) { }
        if (!sidebar) {
            sidebar = 'visible';
        }
    }

    // Mirror the state onto the toggle checkbox and the <html> class list.
    sidebar_toggle.checked = (sidebar === 'visible');
    html.classList.remove('sidebar-visible');
    html.classList.add(`sidebar-${sidebar}`);
</script>
|
|
|
|
<nav id="sidebar" class="sidebar" aria-label="Table of contents">
|
|
<!-- populated by js -->
|
|
<mdbook-sidebar-scrollbox class="sidebar-scrollbox"></mdbook-sidebar-scrollbox>
|
|
<noscript>
|
|
<!-- The title gives the no-JavaScript fallback frame an accessible name. -->
<iframe class="sidebar-iframe-outer" src="../toc.html" title="Table of contents"></iframe>
|
|
</noscript>
|
|
<div id="sidebar-resize-handle" class="sidebar-resize-handle">
|
|
<div class="sidebar-resize-indicator"></div>
|
|
</div>
|
|
</nav>
|
|
|
|
<div id="page-wrapper" class="page-wrapper">
|
|
|
|
<div class="page">
|
|
<div id="menu-bar-hover-placeholder"></div>
|
|
<div id="menu-bar" class="menu-bar sticky">
|
|
<div class="left-buttons">
|
|
<label id="sidebar-toggle" class="icon-button" for="sidebar-toggle-anchor" title="Toggle Table of Contents" aria-label="Toggle Table of Contents" aria-controls="sidebar">
|
|
<i class="fa fa-bars"></i>
|
|
</label>
|
|
<button id="theme-toggle" class="icon-button" type="button" title="Change theme" aria-label="Change theme" aria-haspopup="true" aria-expanded="false" aria-controls="theme-list">
|
|
<i class="fa fa-paint-brush"></i>
|
|
</button>
|
|
<ul id="theme-list" class="theme-popup" aria-label="Themes" role="menu">
|
|
<li role="none"><button role="menuitem" class="theme" id="default_theme">Auto</button></li>
|
|
<li role="none"><button role="menuitem" class="theme" id="light">Light</button></li>
|
|
<li role="none"><button role="menuitem" class="theme" id="rust">Rust</button></li>
|
|
<li role="none"><button role="menuitem" class="theme" id="coal">Coal</button></li>
|
|
<li role="none"><button role="menuitem" class="theme" id="navy">Navy</button></li>
|
|
<li role="none"><button role="menuitem" class="theme" id="ayu">Ayu</button></li>
|
|
</ul>
|
|
<button id="search-toggle" class="icon-button" type="button" title="Search. (Shortkey: s)" aria-label="Toggle Searchbar" aria-expanded="false" aria-keyshortcuts="S" aria-controls="searchbar">
|
|
<i class="fa fa-search"></i>
|
|
</button>
|
|
</div>
|
|
|
|
<h1 class="menu-title">Candle Documentation</h1>
|
|
|
|
<div class="right-buttons">
|
|
<a href="../print.html" title="Print this book" aria-label="Print this book">
|
|
<i id="print-button" class="fa fa-print"></i>
|
|
</a>
|
|
|
|
</div>
|
|
</div>
|
|
|
|
<div id="search-wrapper" class="hidden">
|
|
<form id="searchbar-outer" class="searchbar-outer">
|
|
<!-- aria-label supplies the accessible name; the placeholder alone is not a label. -->
<input type="search" id="searchbar" name="searchbar" placeholder="Search this book ..." aria-label="Search this book" aria-controls="searchresults-outer" aria-describedby="searchresults-header">
|
|
</form>
|
|
<div id="searchresults-outer" class="searchresults-outer hidden">
|
|
<div id="searchresults-header" class="searchresults-header"></div>
|
|
<ul id="searchresults">
|
|
</ul>
|
|
</div>
|
|
</div>
|
|
|
|
<!-- Apply ARIA attributes after the sidebar and the sidebar toggle button are added to the DOM -->
|
|
<script>
    {
        // The block scope keeps these helper constants out of the global namespace.
        const sidebar_shown = sidebar === 'visible';
        const link_tab_index = sidebar_shown ? 0 : -1;

        // Expose the current sidebar state to assistive technology.
        document.getElementById('sidebar-toggle').setAttribute('aria-expanded', sidebar_shown);
        document.getElementById('sidebar').setAttribute('aria-hidden', !sidebar_shown);

        // Sidebar links are tabbable only while the sidebar is visible.
        for (const link of document.querySelectorAll('#sidebar a')) {
            link.setAttribute('tabIndex', link_tab_index);
        }
    }
</script>
|
|
|
|
<div id="content" class="content">
|
|
<main>
|
|
<h1 id="simplified"><a class="header" href="#simplified">Simplified</a></h1>
|
|
<h2 id="how-its-works"><a class="header" href="#how-its-works">How it works</a></h2>
|
|
<p>This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.</p>
|
|
<p>Key points:</p>
|
|
<ol>
|
|
<li>A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.</li>
|
|
<li>The input is a vector of 2 numbers: the percentages of votes received by the first and second candidates in the first round.</li>
|
|
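<li>The prediction is the number 0 or 1, where 1 means that the first candidate will win the second round and 0 means that they will lose. The network itself emits two values, one per outcome, and the index of the larger one is taken as the prediction.</li>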
|
|
<li>For training, samples with real data on the results of the first and second rounds of different elections are used.</li>
|
|
<li>The model is trained by backpropagation using gradient descent and the cross-entropy loss function.</li>
|
|
<li>Model parameters (weights of neurons) are initialized randomly, then optimized during training.</li>
|
|
<li>After every training epoch, the model is evaluated on a held-out test set to measure its accuracy.</li>
|
|
<li>If the accuracy on the test set is still below 100% after the last epoch, the model is considered underfit and training is repeated from a new random initialization.</li>
|
|
</ol>
|
|
<p>Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.</p>
|
|
<pre><code class="language-rust ignore">/// Number of input features per sample: one first-round vote percentage per candidate.
const VOTE_DIM: usize = 2;
/// Width of the label; the network emits `RESULTS + 1` logits, one per outcome class.
const RESULTS: usize = 1;
/// Upper bound on the number of training epochs in a single `train` call.
const EPOCHS: usize = 10;
/// Output width of the first hidden layer.
const LAYER1_OUT_SIZE: usize = 4;
/// Output width of the second hidden layer.
const LAYER2_OUT_SIZE: usize = 2;
/// Learning rate handed to the SGD optimizer.
const LEARNING_RATE: f64 = 0.05;

/// Tensors used to train and evaluate the model.
#[derive(Clone)]
pub struct Dataset {
    /// Training inputs: one row of first-round vote percentages per election.
    pub train_votes: Tensor,
    /// Training labels: the second-round outcome (0 or 1) for each training row.
    pub train_results: Tensor,
    /// Evaluation inputs, laid out like `train_votes`.
    pub test_votes: Tensor,
    /// Evaluation labels, laid out like `train_results`.
    pub test_results: Tensor,
}

/// Perceptron with two hidden layers followed by a linear output layer.
struct MultiLevelPerceptron {
    ln1: Linear,
    ln2: Linear,
    ln3: Linear,
}

impl MultiLevelPerceptron {
    /// Creates the three linear layers, registering their variables in `vs`
    /// under the prefixes `ln1`, `ln2` and `ln3`.
    fn new(vs: VarBuilder) -> Result<Self> {
        Ok(Self {
            ln1: candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?,
            ln2: candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?,
            ln3: candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?,
        })
    }

    /// Runs `xs` through the network and returns the raw logits; ReLU is applied
    /// after each hidden layer and no activation is applied to the output.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden1 = self.ln1.forward(xs)?.relu()?;
        let hidden2 = self.ln2.forward(&hidden1)?.relu()?;
        self.ln3.forward(&hidden2)
    }
}
</code></pre>
|
|
<pre><code class="language-rust ignore">/// Trains a freshly initialized `MultiLevelPerceptron` on `m` using device `dev`.
///
/// Runs at most `EPOCHS` epochs of full-batch SGD and evaluates on the test
/// split after every step, stopping as soon as test accuracy reaches 100%.
/// Returns the model on success, or an error if the accuracy after the last
/// epoch is below 100%.
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
    // Move every tensor of the dataset onto the target device.
    let train_results = m.train_results.to_device(dev)?;
    let train_votes = m.train_votes.to_device(dev)?;
    let test_votes = m.test_votes.to_device(dev)?;
    let test_results = m.test_results.to_device(dev)?;

    // A new VarMap per call means every call starts from new parameters.
    let varmap = VarMap::new();
    let model = MultiLevelPerceptron::new(VarBuilder::from_varmap(&varmap, DType::F32, dev))?;
    let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;

    let mut final_accuracy: f32 = 0.0;
    for epoch in 1..=EPOCHS {
        // Training step: negative log-likelihood over log-softmax of the logits.
        let log_probs = ops::log_softmax(&model.forward(&train_votes)?, D::Minus1)?;
        let loss = loss::nll(&log_probs, &train_results)?;
        sgd.backward_step(&loss)?;

        // Evaluation: count the test rows whose arg-max class equals the label.
        let predictions = model.forward(&test_votes)?.argmax(D::Minus1)?;
        let correct = predictions
            .eq(&test_results)?
            .to_dtype(DType::F32)?
            .sum_all()?
            .to_scalar::<f32>()?;
        final_accuracy = 100. * (correct / test_results.dims1()? as f32);

        // The printed loss was computed before this epoch's parameter update.
        println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
            loss.to_scalar::<f32>()?,
            final_accuracy
        );
        if final_accuracy == 100.0 {
            break;
        }
    }

    if final_accuracy < 100.0 {
        return Err(anyhow::Error::msg("The model is not trained well enough."));
    }
    Ok(model)
}</code></pre>
|
|
<pre><code class="language-rust ignore">/// End-to-end example: builds a tiny election dataset, retrains until `train`
/// succeeds, then prints the prediction for one unseen pair of vote percentages.
#[tokio::test]
async fn simplified() -> anyhow::Result<()> {
    let dev = Device::cuda_if_available(0)?;

    // Training inputs: one (first candidate, second candidate) pair per row.
    let train_votes_vec: Vec<u32> = vec![
        15, 10,
        10, 15,
        5, 12,
        30, 20,
        16, 12,
        13, 25,
        6, 14,
        31, 21,
    ];
    let train_rows = train_votes_vec.len() / VOTE_DIM;
    let train_votes_tensor = Tensor::from_vec(train_votes_vec, (train_rows, VOTE_DIM), &dev)?
        .to_dtype(DType::F32)?;

    // Training labels, one per input row.
    let train_results_vec: Vec<u32> = vec![1, 0, 0, 1, 1, 0, 0, 1];
    let train_results_tensor = Tensor::from_vec(train_results_vec, train_rows, &dev)?;

    // Test inputs and labels, laid out like the training data.
    let test_votes_vec: Vec<u32> = vec![
        13, 9,
        8, 14,
        3, 10,
    ];
    let test_rows = test_votes_vec.len() / VOTE_DIM;
    let test_votes_tensor = Tensor::from_vec(test_votes_vec, (test_rows, VOTE_DIM), &dev)?
        .to_dtype(DType::F32)?;

    let test_results_vec: Vec<u32> = vec![1, 0, 0];
    let test_labels = test_results_vec.len();
    let test_results_tensor = Tensor::from_vec(test_results_vec, test_labels, &dev)?;

    let m = Dataset {
        train_votes: train_votes_tensor,
        train_results: train_results_tensor,
        test_votes: test_votes_tensor,
        test_results: test_results_tensor,
    };

    // Retry until a training run succeeds. There is no attempt limit, so this
    // loops for as long as `train` keeps returning an error.
    let trained_model: MultiLevelPerceptron = loop {
        println!("Trying to train neural network.");
        match train(m.clone(), &dev) {
            Ok(model) => break model,
            Err(e) => println!("Error: {}", e),
        }
    };

    // The clone keeps `real_world_votes` available for the printout below.
    let real_world_votes: Vec<u32> = vec![
        13, 22,
    ];
    let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?
        .to_dtype(DType::F32)?;

    // Arg-max over the two logits gives the predicted class for the single row.
    let result = trained_model
        .forward(&tensor_test_votes)?
        .argmax(D::Minus1)?
        .to_dtype(DType::F32)?
        .get(0)?
        .to_scalar::<f32>()?;
    println!("real_life_votes: {:?}", real_world_votes);
    println!("neural_network_prediction_result: {:?}", result);

    Ok(())
}</code></pre>
|
|
<h2 id="example-output"><a class="header" href="#example-output">Example output</a></h2>
|
|
<pre><code class="language-bash">Trying to train neural network.
|
|
Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
|
|
Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
|
|
Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
|
|
Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
|
|
Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
|
|
real_life_votes: [13, 22]
|
|
neural_network_prediction_result: 0.0
|
|
</code></pre>
|
|
|
|
</main>
|
|
|
|
<nav class="nav-wrapper" aria-label="Page navigation">
|
|
<!-- Mobile navigation buttons -->
|
|
<a rel="prev" href="../training/training.html" class="mobile-nav-chapters previous" title="Previous chapter" aria-label="Previous chapter" aria-keyshortcuts="Left">
|
|
<i class="fa fa-angle-left"></i>
|
|
</a>
|
|
|
|
<a rel="next prefetch" href="../training/mnist.html" class="mobile-nav-chapters next" title="Next chapter" aria-label="Next chapter" aria-keyshortcuts="Right">
|
|
<i class="fa fa-angle-right"></i>
|
|
</a>
|
|
|
|
<div style="clear: both"></div>
|
|
</nav>
|
|
</div>
|
|
</div>
|
|
|
|
<nav class="nav-wide-wrapper" aria-label="Page navigation">
|
|
<a rel="prev" href="../training/training.html" class="nav-chapters previous" title="Previous chapter" aria-label="Previous chapter" aria-keyshortcuts="Left">
|
|
<i class="fa fa-angle-left"></i>
|
|
</a>
|
|
|
|
<a rel="next prefetch" href="../training/mnist.html" class="nav-chapters next" title="Next chapter" aria-label="Next chapter" aria-keyshortcuts="Right">
|
|
<i class="fa fa-angle-right"></i>
|
|
</a>
|
|
</nav>
|
|
|
|
</div>
|
|
|
|
|
|
|
|
|
|
<script>
|
|
// Global flag set before the book scripts below are loaded; presumably read by
// book.js to enable copy buttons on code blocks — verify against book.js.
window.playground_copyable = true;
|
|
</script>
|
|
|
|
|
|
<script src="../elasticlunr.min.js"></script>
|
|
<script src="../mark.min.js"></script>
|
|
<script src="../searcher.js"></script>
|
|
|
|
<script src="../clipboard.min.js"></script>
|
|
<script src="../highlight.js"></script>
|
|
<script src="../book.js"></script>
|
|
|
|
<!-- Custom JS scripts -->
|
|
|
|
|
|
</div>
|
|
</body>
|
|
</html>
|