feat: add network policy

This commit is contained in:
2025-02-28 12:40:45 +00:00
parent 29488d75a3
commit 5e2088f09c
21 changed files with 700 additions and 155 deletions

3
Cargo.lock generated
View File

@ -1841,8 +1841,7 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "ipnetwork"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763"
source = "git+https://github.com/v0l/ipnetwork.git?rev=51b4816358f255ecbcff90b739fbbd4b1cbc2d6a#51b4816358f255ecbcff90b739fbbd4b1cbc2d6a"
[[package]]
name = "is-terminal"

View File

@ -29,7 +29,7 @@ nostr = { version = "0.39.0", default-features = false, features = ["std"] }
base64 = { version = "0.22.1", features = ["alloc"] }
urlencoding = "2.1.3"
fedimint-tonic-lnd = { version = "0.2.0", default-features = false, features = ["invoicesrpc"] }
ipnetwork = "0.21.1"
ipnetwork = { git = "https://github.com/v0l/ipnetwork.git", rev = "51b4816358f255ecbcff90b739fbbd4b1cbc2d6a" }
rand = "0.9.0"
clap = { version = "4.5.21", features = ["derive"] }
ssh2 = "0.9.4"

View File

@ -21,19 +21,19 @@ lnd:
macaroon: "$HOME/.lnd/data/chain/bitcoin/mainnet/admin.macaroon"
# Number of days after a VM expires to delete
delete_after: 3
delete-after: 3
# Provisioner is the main process which handles creating/deleting VM's
# Currently supports: Proxmox
provisioner:
proxmox:
# Read-only mode prevents spawning VM's
read_only: false
read-only: false
# Proxmox (QEMU) settings used for spawning VM's
qemu:
bios: "ovmf"
machine: "q35"
os_type: "l26"
os-type: "l26"
bridge: "vmbr0"
cpu: "kvm64"
vlan: 100
@ -89,6 +89,12 @@ router:
url: "https://my-router.net"
username: "admin"
password: "admin"
# Interface where the static ARP entry is added
arp_interface: "bridge1"
network-policy:
# How packets get to the VM
# (default "auto", nothing to do, packets will always arrive)
access:
# Static ARP entries are added to the router for each provisioned IP
static-arp:
# Interface where the static ARP entry is added
interface: "bridge1"
```

View File

@ -3,14 +3,14 @@ lnd:
url: "https://127.0.0.1:10003"
cert: "/home/kieran/.polar/networks/2/volumes/lnd/alice/tls.cert"
macaroon: "/home/kieran/.polar/networks/2/volumes/lnd/alice/data/chain/bitcoin/regtest/admin.macaroon"
delete_after: 3
delete-after: 3
provisioner:
proxmox:
read_only: false
read-only: false
qemu:
bios: "ovmf"
machine: "q35"
os_type: "l26"
os-type: "l26"
bridge: "vmbr0"
cpu: "host"
vlan: 100

View File

@ -1,6 +1,4 @@
use anyhow::Result;
use async_trait::async_trait;
mod model;
#[cfg(feature = "mysql")]
mod mysql;
@ -9,6 +7,8 @@ pub use model::*;
#[cfg(feature = "mysql")]
pub use mysql::*;
pub use async_trait::async_trait;
#[async_trait]
pub trait LNVpsDb: Sync + Send {
/// Migrate database
@ -65,6 +65,9 @@ pub trait LNVpsDb: Sync + Send {
/// List available IP Ranges
async fn list_ip_range(&self) -> Result<Vec<IpRange>>;
/// List available IP Ranges in a given region
async fn list_ip_range_in_region(&self, region_id: u64) -> Result<Vec<IpRange>>;
/// Get a VM cost plan by id
async fn get_cost_plan(&self, id: u64) -> Result<VmCostPlan>;

View File

@ -220,6 +220,7 @@ pub struct VmIpAssignment {
pub vm_id: u64,
pub ip_range_id: u64,
pub ip: String,
pub deleted: bool,
}
impl Display for VmIpAssignment {

View File

@ -171,7 +171,15 @@ impl LNVpsDb for LNVpsDbMysql {
}
async fn list_ip_range(&self) -> Result<Vec<IpRange>> {
sqlx::query_as("select * from ip_range")
sqlx::query_as("select * from ip_range where enabled = 1")
.fetch_all(&self.db)
.await
.map_err(Error::new)
}
async fn list_ip_range_in_region(&self, region_id: u64) -> Result<Vec<IpRange>> {
sqlx::query_as("select * from ip_range where region_id = ? and enabled = 1")
.bind(region_id)
.fetch_all(&self.db)
.await
.map_err(Error::new)

View File

@ -1,4 +1,4 @@
mod routes;
mod model;
mod routes;
pub use routes::routes;
pub use routes::routes;

View File

@ -1,9 +1,12 @@
use nostr::util::hex;
use crate::status::VmState;
use chrono::{DateTime, Utc};
use ipnetwork::IpNetwork;
use lnvps_db::VmHostRegion;
use nostr::util::hex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::str::FromStr;
#[derive(Serialize, Deserialize, JsonSchema)]
pub struct ApiVmStatus {
@ -48,15 +51,20 @@ impl From<lnvps_db::UserSshKey> for ApiUserSshKey {
pub struct ApiVmIpAssignment {
pub id: u64,
pub ip: String,
pub range: String,
pub gateway: String,
}
impl ApiVmIpAssignment {
pub fn from(ip: lnvps_db::VmIpAssignment, range: &lnvps_db::IpRange) -> Self {
pub fn from(ip: &lnvps_db::VmIpAssignment, range: &lnvps_db::IpRange) -> Self {
ApiVmIpAssignment {
id: ip.id,
ip: ip.ip,
range: range.cidr.clone(),
ip: IpNetwork::new(
IpNetwork::from_str(&ip.ip).unwrap().ip(),
IpNetwork::from_str(&range.cidr).unwrap().prefix(),
)
.unwrap()
.to_string(),
gateway: range.gateway.to_string(),
}
}
}

View File

@ -155,7 +155,7 @@ async fn vm_to_status(
let range = ip_ranges
.get(&i.ip_range_id)
.expect("ip range id not found");
ApiVmIpAssignment::from(i, range)
ApiVmIpAssignment::from(&i, range)
})
.collect(),
})
@ -457,6 +457,7 @@ async fn v1_get_payment(
} else {
return ApiData::err("Invalid payment id");
};
let payment = db.get_vm_payment(&id).await?;
let vm = db.get_vm(payment.vm_id).await?;
if vm.user_id != uid {

View File

@ -64,12 +64,10 @@ async fn main() -> Result<(), Error> {
} else {
None
};
let router = settings.router.as_ref().map(|r| r.get_router());
let router = settings.get_router()?;
let status = VmStateCache::new();
let worker_provisioner =
settings
.provisioner
.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
settings.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
worker_provisioner.init().await?;
let mut worker = Worker::new(
@ -82,11 +80,11 @@ async fn main() -> Result<(), Error> {
let sender = worker.sender();
// send a startup notification
if let Some(admin) = &settings.smtp.and_then(|s| s.admin) {
if let Some(admin) = settings.smtp.as_ref().and_then(|s| s.admin) {
sender.send(WorkJob::SendNotification {
title: Some("Startup".to_string()),
message: "System is starting!".to_string(),
user_id: *admin,
user_id: admin,
})?;
}
@ -132,11 +130,8 @@ async fn main() -> Result<(), Error> {
}
});
let router = settings.router.as_ref().map(|r| r.get_router());
let provisioner =
settings
.provisioner
.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
let router = settings.get_router()?;
let provisioner = settings.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
let db: Box<dyn LNVpsDb> = Box::new(db.clone());
let pv: Box<dyn Provisioner> = Box::new(provisioner);

View File

@ -10,3 +10,6 @@ pub mod settings;
pub mod ssh_client;
pub mod status;
pub mod worker;
#[cfg(test)]
pub mod mocks;

337
src/mocks.rs Normal file
View File

@ -0,0 +1,337 @@
use crate::router::{ArpEntry, Router};
use crate::settings::NetworkPolicy;
use anyhow::anyhow;
use chrono::Utc;
use lnvps_db::{
async_trait, IpRange, LNVpsDb, User, UserSshKey, Vm, VmCostPlan, VmHost, VmHostDisk,
VmHostKind, VmHostRegion, VmIpAssignment, VmOsImage, VmPayment, VmTemplate,
};
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Debug, Clone)]
pub struct MockDb {
pub regions: Arc<Mutex<HashMap<u64, VmHostRegion>>>,
pub hosts: Arc<Mutex<HashMap<u64, VmHost>>>,
pub users: Arc<Mutex<HashMap<u64, User>>>,
pub vms: Arc<Mutex<HashMap<u64, Vm>>>,
pub ip_range: Arc<Mutex<HashMap<u64, IpRange>>>,
pub ip_assignments: Arc<Mutex<HashMap<u64, VmIpAssignment>>>,
}
impl MockDb {
pub fn empty() -> MockDb {
Self {
..Default::default()
}
}
}
impl Default for MockDb {
fn default() -> Self {
let mut regions = HashMap::new();
regions.insert(
1,
VmHostRegion {
id: 1,
name: "Mock".to_string(),
enabled: true,
},
);
let mut ip_ranges = HashMap::new();
ip_ranges.insert(
1,
IpRange {
id: 1,
cidr: "10.0.0.0/8".to_string(),
gateway: "10.0.0.1".to_string(),
enabled: true,
region_id: 1,
},
);
let mut hosts = HashMap::new();
hosts.insert(
1,
VmHost {
id: 1,
kind: VmHostKind::Proxmox,
region_id: 1,
name: "mock-host".to_string(),
ip: "https://localhost".to_string(),
cpu: 4,
memory: 8192,
enabled: true,
api_token: "".to_string(),
},
);
Self {
regions: Arc::new(Mutex::new(regions)),
ip_range: Arc::new(Mutex::new(ip_ranges)),
hosts: Arc::new(Mutex::new(hosts)),
users: Arc::new(Default::default()),
vms: Arc::new(Default::default()),
ip_assignments: Arc::new(Default::default()),
}
}
}
#[async_trait]
impl LNVpsDb for MockDb {
async fn migrate(&self) -> anyhow::Result<()> {
Ok(())
}
async fn upsert_user(&self, pubkey: &[u8; 32]) -> anyhow::Result<u64> {
let mut users = self.users.lock().await;
if let Some(e) = users.iter().find(|(k, u)| u.pubkey == *pubkey) {
Ok(*e.0)
} else {
let max = *users.keys().max().unwrap_or(&0);
users.insert(
max + 1,
User {
id: max + 1,
pubkey: pubkey.to_vec(),
created: Utc::now(),
email: None,
contact_nip4: false,
contact_nip17: false,
contact_email: false,
},
);
Ok(max + 1)
}
}
async fn get_user(&self, id: u64) -> anyhow::Result<User> {
let mut users = self.users.lock().await;
Ok(users.get(&id).ok_or(anyhow!("no user"))?.clone())
}
async fn update_user(&self, user: &User) -> anyhow::Result<()> {
let mut users = self.users.lock().await;
if let Some(u) = users.get_mut(&user.id) {
u.email = user.email.clone();
u.contact_email = user.contact_email.clone();
u.contact_nip17 = user.contact_nip17.clone();
u.contact_nip4 = user.contact_nip4.clone();
}
Ok(())
}
async fn delete_user(&self, id: u64) -> anyhow::Result<()> {
let mut users = self.users.lock().await;
users.remove(&id);
Ok(())
}
async fn insert_user_ssh_key(&self, new_key: &UserSshKey) -> anyhow::Result<u64> {
todo!()
}
async fn get_user_ssh_key(&self, id: u64) -> anyhow::Result<UserSshKey> {
todo!()
}
async fn delete_user_ssh_key(&self, id: u64) -> anyhow::Result<()> {
todo!()
}
async fn list_user_ssh_key(&self, user_id: u64) -> anyhow::Result<Vec<UserSshKey>> {
todo!()
}
async fn get_host_region(&self, id: u64) -> anyhow::Result<VmHostRegion> {
let regions = self.regions.lock().await;
Ok(regions.get(&id).ok_or(anyhow!("no region"))?.clone())
}
async fn list_hosts(&self) -> anyhow::Result<Vec<VmHost>> {
let hosts = self.hosts.lock().await;
Ok(hosts.values().filter(|h| h.enabled).cloned().collect())
}
async fn get_host(&self, id: u64) -> anyhow::Result<VmHost> {
let hosts = self.hosts.lock().await;
Ok(hosts.get(&id).ok_or(anyhow!("no host"))?.clone())
}
async fn update_host(&self, host: &VmHost) -> anyhow::Result<()> {
let mut hosts = self.hosts.lock().await;
if let Some(h) = hosts.get_mut(&host.id) {
h.enabled = host.enabled;
h.cpu = host.cpu;
h.memory = host.memory;
}
Ok(())
}
async fn list_host_disks(&self, host_id: u64) -> anyhow::Result<Vec<VmHostDisk>> {
todo!()
}
async fn get_os_image(&self, id: u64) -> anyhow::Result<VmOsImage> {
todo!()
}
async fn list_os_image(&self) -> anyhow::Result<Vec<VmOsImage>> {
todo!()
}
async fn get_ip_range(&self, id: u64) -> anyhow::Result<IpRange> {
let ip_range = self.ip_range.lock().await;
Ok(ip_range.get(&id).ok_or(anyhow!("no ip range"))?.clone())
}
async fn list_ip_range(&self) -> anyhow::Result<Vec<IpRange>> {
let ip_range = self.ip_range.lock().await;
Ok(ip_range.values().filter(|r| r.enabled).cloned().collect())
}
async fn list_ip_range_in_region(&self, region_id: u64) -> anyhow::Result<Vec<IpRange>> {
let ip_range = self.ip_range.lock().await;
Ok(ip_range
.values()
.filter(|r| r.enabled && r.region_id == region_id)
.cloned()
.collect())
}
async fn get_cost_plan(&self, id: u64) -> anyhow::Result<VmCostPlan> {
todo!()
}
async fn get_vm_template(&self, id: u64) -> anyhow::Result<VmTemplate> {
todo!()
}
async fn list_vm_templates(&self) -> anyhow::Result<Vec<VmTemplate>> {
todo!()
}
async fn list_vms(&self) -> anyhow::Result<Vec<Vm>> {
todo!()
}
async fn list_expired_vms(&self) -> anyhow::Result<Vec<Vm>> {
todo!()
}
async fn list_user_vms(&self, id: u64) -> anyhow::Result<Vec<Vm>> {
todo!()
}
async fn get_vm(&self, vm_id: u64) -> anyhow::Result<Vm> {
todo!()
}
async fn insert_vm(&self, vm: &Vm) -> anyhow::Result<u64> {
todo!()
}
async fn delete_vm(&self, vm_id: u64) -> anyhow::Result<()> {
todo!()
}
async fn update_vm(&self, vm: &Vm) -> anyhow::Result<()> {
todo!()
}
async fn insert_vm_ip_assignment(&self, ip_assignment: &VmIpAssignment) -> anyhow::Result<u64> {
let mut ip_assignments = self.ip_assignments.lock().await;
let max = *ip_assignments.keys().max().unwrap_or(&0);
ip_assignments.insert(
max + 1,
VmIpAssignment {
id: max + 1,
vm_id: ip_assignment.vm_id,
ip_range_id: ip_assignment.ip_range_id,
ip: ip_assignment.ip.clone(),
deleted: false,
},
);
Ok(max + 1)
}
async fn list_vm_ip_assignments(&self, vm_id: u64) -> anyhow::Result<Vec<VmIpAssignment>> {
let ip_assignments = self.ip_assignments.lock().await;
Ok(ip_assignments
.values()
.filter(|a| a.vm_id == vm_id && !a.deleted)
.cloned()
.collect())
}
async fn list_vm_ip_assignments_in_range(
&self,
range_id: u64,
) -> anyhow::Result<Vec<VmIpAssignment>> {
let ip_assignments = self.ip_assignments.lock().await;
Ok(ip_assignments
.values()
.filter(|a| a.ip_range_id == range_id && !a.deleted)
.cloned()
.collect())
}
async fn delete_vm_ip_assignment(&self, vm_id: u64) -> anyhow::Result<()> {
let mut ip_assignments = self.ip_assignments.lock().await;
for ip_assignment in ip_assignments.values_mut() {
if ip_assignment.vm_id == vm_id {
ip_assignment.deleted = true;
}
}
Ok(())
}
async fn list_vm_payment(&self, vm_id: u64) -> anyhow::Result<Vec<VmPayment>> {
todo!()
}
async fn insert_vm_payment(&self, vm_payment: &VmPayment) -> anyhow::Result<()> {
todo!()
}
async fn get_vm_payment(&self, id: &Vec<u8>) -> anyhow::Result<VmPayment> {
todo!()
}
async fn update_vm_payment(&self, vm_payment: &VmPayment) -> anyhow::Result<()> {
todo!()
}
async fn vm_payment_paid(&self, id: &VmPayment) -> anyhow::Result<()> {
todo!()
}
async fn last_paid_invoice(&self) -> anyhow::Result<Option<VmPayment>> {
todo!()
}
}
struct MockRouter {
pub policy: NetworkPolicy,
}
#[async_trait]
impl Router for MockRouter {
async fn list_arp_entry(&self) -> anyhow::Result<Vec<ArpEntry>> {
todo!()
}
async fn add_arp_entry(
&self,
ip: IpAddr,
mac: &str,
interface: &str,
comment: Option<&str>,
) -> anyhow::Result<()> {
todo!()
}
async fn remove_arp_entry(&self, id: &str) -> anyhow::Result<()> {
todo!()
}
}

View File

@ -4,9 +4,9 @@ use crate::host::proxmox::{
ConfigureVm, CreateVm, DownloadUrlRequest, ProxmoxClient, ResizeDiskRequest, StorageContent,
VmBios, VmConfig,
};
use crate::provisioner::Provisioner;
use crate::provisioner::{NetworkProvisioner, Provisioner, ProvisionerMethod};
use crate::router::Router;
use crate::settings::{QemuConfig, SshConfig};
use crate::settings::{NetworkAccessPolicy, NetworkPolicy, QemuConfig, SshConfig};
use crate::ssh_client::SshClient;
use anyhow::{bail, Result};
use chrono::{Days, Months, Utc};
@ -25,18 +25,21 @@ use rocket::futures::{SinkExt, StreamExt};
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::ops::Add;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
pub struct LNVpsProvisioner {
db: Box<dyn LNVpsDb>,
db: Arc<dyn LNVpsDb>,
router: Option<Box<dyn Router>>,
lnd: Client,
rates: ExchangeRateCache,
read_only: bool,
config: QemuConfig,
network_policy: NetworkPolicy,
ssh: Option<SshConfig>,
}
@ -44,6 +47,7 @@ impl LNVpsProvisioner {
pub fn new(
read_only: bool,
config: QemuConfig,
network_policy: NetworkPolicy,
ssh: Option<SshConfig>,
db: impl LNVpsDb + 'static,
router: Option<impl Router + 'static>,
@ -51,9 +55,10 @@ impl LNVpsProvisioner {
rates: ExchangeRateCache,
) -> Self {
Self {
db: Box::new(db),
db: Arc::new(db),
router: router.map(|r| Box::new(r) as Box<dyn Router>),
lnd,
network_policy,
rates,
config,
read_only,
@ -73,14 +78,9 @@ impl LNVpsProvisioner {
}
}
async fn get_vm_config(&self, vm: &Vm) -> Result<VmConfig> {
async fn get_vm_config(&self, vm: &Vm, ips: &Vec<VmIpAssignment>) -> Result<VmConfig> {
let ssh_key = self.db.get_user_ssh_key(vm.ssh_key_id).await?;
let mut ips = self.db.list_vm_ip_assignments(vm.id).await?;
if ips.is_empty() {
ips = self.allocate_ips(vm.id).await?;
}
let ip_range_ids: HashSet<u64> = ips.iter().map(|i| i.ip_range_id).collect();
let ip_ranges: Vec<_> = ip_range_ids
.iter()
@ -148,6 +148,30 @@ impl LNVpsProvisioner {
..Default::default()
})
}
async fn save_ip_assignment(&self, vm: &Vm, assignment: &VmIpAssignment) -> Result<()> {
// apply network policy
match &self.network_policy.access {
NetworkAccessPolicy::StaticArp { interface } => {
if let Some(r) = self.router.as_ref() {
r.add_arp_entry(
IpAddr::from_str(&assignment.ip)?,
&vm.mac_address,
interface,
Some(&format!("VM{}", vm.id)),
)
.await?;
} else {
bail!("No router found to apply static arp entry!")
}
}
_ => {}
}
// save to db
self.db.insert_vm_ip_assignment(&assignment).await?;
Ok(())
}
}
#[async_trait]
@ -303,69 +327,29 @@ impl Provisioner for LNVpsProvisioner {
async fn allocate_ips(&self, vm_id: u64) -> Result<Vec<VmIpAssignment>> {
let vm = self.db.get_vm(vm_id).await?;
let existing_ips = self.db.list_vm_ip_assignments(vm_id).await?;
if !existing_ips.is_empty() {
return Ok(existing_ips);
}
let template = self.db.get_vm_template(vm.template_id).await?;
let ips = self.db.list_vm_ip_assignments(vm.id).await?;
let prov = NetworkProvisioner::new(
ProvisionerMethod::Random,
self.network_policy.clone(),
self.db.clone(),
);
if !ips.is_empty() {
bail!("IP resources are already assigned");
}
let ip = prov.pick_ip_for_region(template.region_id).await?;
let assignment = VmIpAssignment {
id: 0,
vm_id,
ip_range_id: ip.range_id,
ip: ip.ip.to_string(),
deleted: false,
};
let ip_ranges = self.db.list_ip_range().await?;
let ip_ranges: Vec<IpRange> = ip_ranges
.into_iter()
.filter(|i| i.region_id == template.region_id && i.enabled)
.collect();
if ip_ranges.is_empty() {
bail!("No ip range found in this region");
}
let mut ret = vec![];
// Try all ranges
// TODO: pick round-robin ranges
// TODO: pick one of each type
'ranges: for range in ip_ranges {
let range_cidr: IpNetwork = range.cidr.parse()?;
let ips = self.db.list_vm_ip_assignments_in_range(range.id).await?;
let ips: HashSet<IpNetwork> = ips.iter().map_while(|i| i.ip.parse().ok()).collect();
// pick an IP at random
let cidr: Vec<IpAddr> = {
let mut rng = rand::thread_rng();
range_cidr.iter().choose(&mut rng).into_iter().collect()
};
for ip in cidr {
let ip_net = IpNetwork::new(ip, range_cidr.prefix())?;
if !ips.contains(&ip_net) {
info!("Attempting to allocate IP for {vm_id} to {ip}");
let mut assignment = VmIpAssignment {
id: 0,
vm_id,
ip_range_id: range.id,
ip: ip_net.to_string(),
..Default::default()
};
// add arp entry for router
if let Some(r) = self.router.as_ref() {
r.add_arp_entry(ip, &vm.mac_address, Some(&format!("VM{}", vm.id)))
.await?;
}
let id = self.db.insert_vm_ip_assignment(&assignment).await?;
assignment.id = id;
ret.push(assignment);
break 'ranges;
}
}
}
if ret.is_empty() {
bail!("No ip ranges found in this region");
}
Ok(ret)
self.save_ip_assignment(&vm, &assignment).await?;
Ok(vec![assignment])
}
async fn spawn_vm(&self, vm_id: u64) -> Result<()> {
@ -378,16 +362,21 @@ impl Provisioner for LNVpsProvisioner {
let client = get_host_client(&host)?;
let vm_id = 100 + vm.id as i32;
// setup network by allocating some IP space
let ips = self.allocate_ips(vm.id).await?;
// create VM
let config = self.get_vm_config(&vm, &ips).await?;
let t_create = client
.create_vm(CreateVm {
node: host.name.clone(),
vm_id,
config: self.get_vm_config(&vm).await?,
config,
})
.await?;
client.wait_for_task(&t_create).await?;
// save
// import the disk
// TODO: find a way to avoid using SSH
if let Some(ssh_config) = &self.ssh {
@ -451,8 +440,10 @@ impl Provisioner for LNVpsProvisioner {
.await?;
client.wait_for_task(&j_resize).await?;
let j_start = client.start_vm(&host.name, vm_id as u64).await?;
client.wait_for_task(&j_start).await?;
// try start, otherwise ignore error (maybe its already running)
if let Ok(j_start) = client.start_vm(&host.name, vm_id as u64).await {
client.wait_for_task(&j_start).await?;
}
Ok(())
}
@ -549,6 +540,7 @@ impl Provisioner for LNVpsProvisioner {
async fn patch_vm(&self, vm_id: u64) -> Result<()> {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let ips = self.db.list_vm_ip_assignments(vm.id).await?;
let client = get_host_client(&host)?;
let host_vm_id = vm.id + 100;
@ -562,7 +554,7 @@ impl Provisioner for LNVpsProvisioner {
scsi_0: None,
scsi_1: None,
efi_disk_0: None,
..self.get_vm_config(&vm).await?
..self.get_vm_config(&vm, &ips).await?
},
})
.await?;

View File

@ -4,7 +4,11 @@ use rocket::async_trait;
use tokio::net::TcpStream;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
pub mod lnvps;
mod lnvps;
mod network;
pub use lnvps::*;
pub use network::*;
#[async_trait]
pub trait Provisioner: Send + Sync {

136
src/provisioner/network.rs Normal file
View File

@ -0,0 +1,136 @@
use crate::settings::NetworkPolicy;
use anyhow::{bail, Result};
use ipnetwork::IpNetwork;
use lnvps_db::LNVpsDb;
use rand::prelude::IteratorRandom;
use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::Arc;
#[derive(Debug, Clone, Copy)]
pub enum ProvisionerMethod {
Sequential,
Random,
}
#[derive(Debug, Clone, Copy)]
pub struct AvailableIp {
pub ip: IpAddr,
pub range_id: u64,
pub region_id: u64,
}
#[derive(Clone)]
pub struct NetworkProvisioner {
method: ProvisionerMethod,
settings: NetworkPolicy,
db: Arc<dyn LNVpsDb>,
}
impl NetworkProvisioner {
pub fn new(method: ProvisionerMethod, settings: NetworkPolicy, db: Arc<dyn LNVpsDb>) -> Self {
Self {
method,
settings,
db,
}
}
/// Pick an IP from one of the available ip ranges
/// This method MUST return a free IP which can be used
pub async fn pick_ip_for_region(&self, region_id: u64) -> Result<AvailableIp> {
let ip_ranges = self.db.list_ip_range_in_region(region_id).await?;
if ip_ranges.is_empty() {
bail!("No ip range found in this region");
}
for range in ip_ranges {
let range_cidr: IpNetwork = range.cidr.parse()?;
let ips = self.db.list_vm_ip_assignments_in_range(range.id).await?;
let ips: HashSet<IpAddr> = ips.iter().map_while(|i| i.ip.parse().ok()).collect();
// pick an IP at random
let ip_pick = {
let first_ip = range_cidr.iter().next().unwrap();
let last_ip = range_cidr.iter().last().unwrap();
match self.method {
ProvisionerMethod::Sequential => range_cidr
.iter()
.find(|i| *i != first_ip && *i != last_ip && !ips.contains(i)),
ProvisionerMethod::Random => {
let mut rng = rand::rng();
loop {
if let Some(i) = range_cidr.iter().choose(&mut rng) {
if i != first_ip && i != last_ip && !ips.contains(&i) {
break Some(i);
}
} else {
break None;
}
}
}
}
};
if let Some(ip_pick) = ip_pick {
return Ok(AvailableIp {
range_id: range.id,
ip: ip_pick,
region_id,
});
}
}
bail!("No IPs available in this region");
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::mocks::*;
use crate::settings::NetworkAccessPolicy;
use lnvps_db::VmIpAssignment;
use std::str::FromStr;
#[tokio::test]
async fn pick_seq_ip_for_region_test() {
let db: Arc<dyn LNVpsDb> = Arc::new(MockDb::default());
let mgr = NetworkProvisioner::new(
ProvisionerMethod::Sequential,
NetworkPolicy {
access: NetworkAccessPolicy::Auto,
},
db.clone(),
);
let ip = mgr.pick_ip_for_region(1).await.expect("No ip found in db");
assert_eq!(1, ip.region_id);
assert_eq!(IpAddr::from_str("10.0.0.1").unwrap(), ip.ip);
db.insert_vm_ip_assignment(&VmIpAssignment {
id: 0,
vm_id: 0,
ip_range_id: ip.range_id,
ip: ip.ip.to_string(),
deleted: false,
})
.await
.expect("Could not insert vm ip");
let ip = mgr.pick_ip_for_region(1).await.expect("No ip found in db");
assert_eq!(IpAddr::from_str("10.0.0.2").unwrap(), ip.ip);
}
#[tokio::test]
async fn pick_rng_ip_for_region_test() {
let db: Arc<dyn LNVpsDb> = Arc::new(MockDb::default());
let mgr = NetworkProvisioner::new(
ProvisionerMethod::Random,
NetworkPolicy {
access: NetworkAccessPolicy::Auto,
},
db,
);
let ip = mgr.pick_ip_for_region(1).await.expect("No ip found in db");
assert_eq!(1, ip.region_id);
}
}

View File

@ -14,11 +14,10 @@ pub struct MikrotikRouter {
username: String,
password: String,
client: Client,
arp_interface: String,
}
impl MikrotikRouter {
pub fn new(url: &str, username: &str, password: &str, arp_interface: &str) -> Self {
pub fn new(url: &str, username: &str, password: &str) -> Self {
Self {
url: url.parse().unwrap(),
username: username.to_string(),
@ -27,7 +26,6 @@ impl MikrotikRouter {
.danger_accept_invalid_certs(true)
.build()
.unwrap(),
arp_interface: arp_interface.to_string(),
}
}
@ -73,7 +71,13 @@ impl Router for MikrotikRouter {
Ok(rsp)
}
async fn add_arp_entry(&self, ip: IpAddr, mac: &str, comment: Option<&str>) -> Result<()> {
async fn add_arp_entry(
&self,
ip: IpAddr,
mac: &str,
arp_interface: &str,
comment: Option<&str>,
) -> Result<()> {
let _rsp: ArpEntry = self
.req(
Method::PUT,
@ -81,7 +85,7 @@ impl Router for MikrotikRouter {
ArpEntry {
address: ip.to_string(),
mac_address: Some(mac.to_string()),
interface: self.arp_interface.to_string(),
interface: arp_interface.to_string(),
comment: comment.map(|c| c.to_string()),
..Default::default()
},

View File

@ -13,7 +13,13 @@ use std::net::IpAddr;
#[async_trait]
pub trait Router: Send + Sync {
async fn list_arp_entry(&self) -> Result<Vec<ArpEntry>>;
async fn add_arp_entry(&self, ip: IpAddr, mac: &str, comment: Option<&str>) -> Result<()>;
async fn add_arp_entry(
&self,
ip: IpAddr,
mac: &str,
interface: &str,
comment: Option<&str>,
) -> Result<()>;
async fn remove_arp_entry(&self, id: &str) -> Result<()>;
}
@ -35,3 +41,12 @@ pub struct ArpEntry {
mod mikrotik;
#[cfg(feature = "mikrotik")]
pub use mikrotik::*;
#[cfg(test)]
mod tests {
use super::*;
use crate::settings::NetworkPolicy;
#[test]
fn provision_ips_with_arp() {}
}

View File

@ -1,21 +1,27 @@
use crate::exchange::ExchangeRateCache;
use crate::provisioner::lnvps::LNVpsProvisioner;
use crate::provisioner::LNVpsProvisioner;
use crate::provisioner::Provisioner;
use crate::router::{MikrotikRouter, Router};
use anyhow::{bail, Result};
use fedimint_tonic_lnd::Client;
use lnvps_db::LNVpsDb;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct Settings {
pub listen: Option<String>,
pub db: String,
pub lnd: LndConfig,
/// Main control process impl
/// Provisioning profiles
pub provisioner: ProvisionerConfig,
/// Network policy
#[serde(default)]
pub network_policy: NetworkPolicy,
/// Number of days after an expired VM is deleted
pub delete_after: u16,
@ -25,7 +31,10 @@ pub struct Settings {
/// Network router config
pub router: Option<RouterConfig>,
/// Nostr config for sending DM's
/// DNS configurations for PTR records
pub dns: Option<DnsServerConfig>,
/// Nostr config for sending DMs
pub nostr: Option<NostrConfig>,
}
@ -43,18 +52,55 @@ pub struct NostrConfig {
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
#[serde(rename_all = "kebab-case")]
pub enum RouterConfig {
Mikrotik {
url: String,
username: String,
password: String,
Mikrotik(ApiConfig),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum DnsServerConfig {
Cloudflare { api: ApiConfig, zone_id: String },
}
/// Generic remote API credentials
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ApiConfig {
/// unique ID of this router, used in references
pub id: String,
/// http://<my-router>
pub url: String,
/// Login credentials used for this router
pub credentials: Credentials,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum Credentials {
UsernamePassword { username: String, password: String },
ApiToken { token: String },
}
/// Policy that determines how packets arrive at the VM
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum NetworkAccessPolicy {
/// No special procedure required for packets to arrive
#[default]
Auto,
/// ARP entries are added statically on the access router
StaticArp {
/// Interface used to add arp entries
arp_interface: String,
interface: String,
},
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(rename_all = "kebab-case")]
pub struct NetworkPolicy {
pub access: NetworkAccessPolicy,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SmtpConfig {
/// Admin user id, for sending system notifications
@ -74,8 +120,9 @@ pub struct SmtpConfig {
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
#[serde(rename_all = "kebab-case")]
pub enum ProvisionerConfig {
#[serde(rename_all = "kebab-case")]
Proxmox {
/// Readonly mode, don't spawn any VM's
read_only: bool,
@ -95,27 +142,23 @@ pub struct SshConfig {
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct QemuConfig {
/// Machine type (q35)
pub machine: String,
/// OS Type
pub os_type: String,
/// Network bridge used for the networking interface
pub bridge: String,
/// CPU type
pub cpu: String,
/// VLAN tag all spawned VM's
pub vlan: Option<u16>,
/// Enable virtualization inside VM
pub kvm: bool,
}
impl ProvisionerConfig {
impl Settings {
pub fn get_provisioner(
&self,
db: impl LNVpsDb + 'static,
@ -123,7 +166,7 @@ impl ProvisionerConfig {
lnd: Client,
exchange: ExchangeRateCache,
) -> impl Provisioner + 'static {
match self {
match &self.provisioner {
ProvisionerConfig::Proxmox {
qemu,
ssh,
@ -131,6 +174,7 @@ impl ProvisionerConfig {
} => LNVpsProvisioner::new(
*read_only,
qemu.clone(),
self.network_policy.clone(),
ssh.clone(),
db,
router,
@ -139,17 +183,15 @@ impl ProvisionerConfig {
),
}
}
}
impl RouterConfig {
pub fn get_router(&self) -> impl Router + 'static {
match self {
RouterConfig::Mikrotik {
url,
username,
password,
arp_interface,
} => MikrotikRouter::new(url, username, password, arp_interface),
pub fn get_router<'a>(&'a self) -> Result<Option<impl Router + 'static>> {
match &self.router {
Some(RouterConfig::Mikrotik(api)) => match &api.credentials {
Credentials::UsernamePassword { username, password } => {
Ok(Some(MikrotikRouter::new(&api.url, username, password)))
}
_ => bail!("Only username/password is supported for Mikrotik routers"),
},
_ => Ok(None),
}
}
}

View File

@ -1,9 +1,9 @@
use anyhow::Result;
use rocket::serde::Deserialize;
use schemars::JsonSchema;
use serde::Serialize;
use std::collections::HashMap;
use std::sync::Arc;
use rocket::serde::Deserialize;
use schemars::JsonSchema;
use tokio::sync::RwLock;
#[derive(Clone, Serialize, Deserialize, Default, JsonSchema)]

View File

@ -327,12 +327,7 @@ impl Worker {
if let Err(e) = self.check_vm(*vm_id).await {
error!("Failed to check VM {}: {}", vm_id, e);
self.queue_admin_notification(
format!(
"Failed to check VM {}:\n{:?}\n{}",
vm_id,
&job,
e
),
format!("Failed to check VM {}:\n{:?}\n{}", vm_id, &job, e),
Some("Job Failed".to_string()),
)?
}
@ -348,11 +343,7 @@ impl Worker {
{
error!("Failed to send notification {}: {}", user_id, e);
self.queue_admin_notification(
format!(
"Failed to send notification:\n{:?}\n{}",
&job,
e
),
format!("Failed to send notification:\n{:?}\n{}", &job, e),
Some("Job Failed".to_string()),
)?
}