refactor: make services generic for mock testing

This commit is contained in:
2025-02-28 15:51:09 +00:00
parent f105b90849
commit e84e6afd54
17 changed files with 605 additions and 353 deletions

View File

@ -1,28 +1,23 @@
use crate::exchange::{ExchangeRateCache, Ticker};
use crate::exchange::{ExchangeRateService, Ticker};
use crate::host::get_host_client;
use crate::host::proxmox::{
ConfigureVm, CreateVm, DownloadUrlRequest, ProxmoxClient, ResizeDiskRequest, StorageContent,
VmBios, VmConfig,
ConfigureVm, CreateVm, DownloadUrlRequest, ImportDiskImageRequest, ProxmoxClient,
ResizeDiskRequest, StorageContent, VmConfig,
};
use crate::lightning::{AddInvoiceRequest, LightningNode};
use crate::provisioner::{NetworkProvisioner, Provisioner, ProvisionerMethod};
use crate::router::Router;
use crate::settings::{NetworkAccessPolicy, NetworkPolicy, QemuConfig, SshConfig};
use crate::ssh_client::SshClient;
use crate::settings::{
NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings,
};
use anyhow::{bail, Result};
use chrono::{Days, Months, Utc};
use fedimint_tonic_lnd::lnrpc::Invoice;
use fedimint_tonic_lnd::tonic::async_trait;
use fedimint_tonic_lnd::Client;
use ipnetwork::IpNetwork;
use lnvps_db::{DiskType, IpRange, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment};
use lnvps_db::{DiskType, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment};
use log::{debug, info, warn};
use nostr::util::hex;
use rand::random;
use rand::seq::IteratorRandom;
use reqwest::Url;
use rocket::futures::future::join_all;
use rocket::futures::{SinkExt, StreamExt};
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::ops::Add;
use std::str::FromStr;
@ -32,37 +27,38 @@ use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
/// Main provisioner class for LNVPS
///
/// Does all the hard work and logic for creating / expiring VM's
pub struct LNVpsProvisioner {
db: Arc<dyn LNVpsDb>,
router: Option<Box<dyn Router>>,
lnd: Client,
rates: ExchangeRateCache,
read_only: bool,
config: QemuConfig,
db: Arc<dyn LNVpsDb>,
node: Arc<dyn LightningNode>,
rates: Arc<dyn ExchangeRateService>,
router: Option<Arc<dyn Router>>,
network_policy: NetworkPolicy,
ssh: Option<SshConfig>,
provisioner_config: ProvisionerConfig,
}
impl LNVpsProvisioner {
pub fn new(
read_only: bool,
config: QemuConfig,
network_policy: NetworkPolicy,
ssh: Option<SshConfig>,
db: impl LNVpsDb + 'static,
router: Option<impl Router + 'static>,
lnd: Client,
rates: ExchangeRateCache,
settings: Settings,
db: Arc<dyn LNVpsDb>,
node: Arc<dyn LightningNode>,
rates: Arc<dyn ExchangeRateService>,
) -> Self {
Self {
db: Arc::new(db),
router: router.map(|r| Box::new(r) as Box<dyn Router>),
lnd,
network_policy,
db,
node,
rates,
config,
read_only,
ssh,
router: settings
.get_router()
.expect("router config"),
network_policy: settings.network_policy,
provisioner_config: settings.provisioner,
read_only: settings.read_only,
}
}
@ -78,98 +74,24 @@ impl LNVpsProvisioner {
}
}
async fn get_vm_config(&self, vm: &Vm, ips: &Vec<VmIpAssignment>) -> Result<VmConfig> {
let ssh_key = self.db.get_user_ssh_key(vm.ssh_key_id).await?;
let ip_range_ids: HashSet<u64> = ips.iter().map(|i| i.ip_range_id).collect();
let ip_ranges: Vec<_> = ip_range_ids
.iter()
.map(|i| self.db.get_ip_range(*i))
.collect();
let ip_ranges: HashMap<u64, IpRange> = join_all(ip_ranges)
.await
.into_iter()
.filter_map(Result::ok)
.map(|i| (i.id, i))
.collect();
let mut ip_config = ips
.iter()
.map_while(|ip| {
if let Ok(net) = ip.ip.parse::<IpNetwork>() {
Some(match net {
IpNetwork::V4(addr) => {
let range = ip_ranges.get(&ip.ip_range_id)?;
format!("ip={},gw={}", addr, range.gateway)
}
IpNetwork::V6(addr) => format!("ip6={}", addr),
})
} else {
None
}
})
.collect::<Vec<_>>();
ip_config.push("ip6=auto".to_string());
let mut net = vec![
format!("virtio={}", vm.mac_address),
format!("bridge={}", self.config.bridge),
];
if let Some(t) = self.config.vlan {
net.push(format!("tag={}", t));
}
let drives = self.db.list_host_disks(vm.host_id).await?;
let drive = if let Some(d) = drives.iter().find(|d| d.enabled) {
d
} else {
bail!("No host drive found!")
};
let template = self.db.get_vm_template(vm.template_id).await?;
Ok(VmConfig {
cpu: Some(self.config.cpu.clone()),
kvm: Some(self.config.kvm),
ip_config: Some(ip_config.join(",")),
machine: Some(self.config.machine.clone()),
net: Some(net.join(",")),
os_type: Some(self.config.os_type.clone()),
on_boot: Some(true),
bios: Some(VmBios::OVMF),
boot: Some("order=scsi0".to_string()),
cores: Some(template.cpu as i32),
memory: Some((template.memory / 1024 / 1024).to_string()),
scsi_hw: Some("virtio-scsi-pci".to_string()),
serial_0: Some("socket".to_string()),
scsi_1: Some(format!("{}:cloudinit", &drive.name)),
ssh_keys: Some(urlencoding::encode(&ssh_key.key_data).to_string()),
efi_disk_0: Some(format!("{}:0,efitype=4m", &drive.name)),
..Default::default()
})
}
async fn save_ip_assignment(&self, vm: &Vm, assignment: &VmIpAssignment) -> Result<()> {
// apply network policy
match &self.network_policy.access {
NetworkAccessPolicy::StaticArp { interface } => {
if let Some(r) = self.router.as_ref() {
r.add_arp_entry(
IpAddr::from_str(&assignment.ip)?,
&vm.mac_address,
interface,
Some(&format!("VM{}", vm.id)),
)
.await?;
} else {
bail!("No router found to apply static arp entry!")
}
if let NetworkAccessPolicy::StaticArp { interface } = &self.network_policy.access {
if let Some(r) = self.router.as_ref() {
r.add_arp_entry(
IpAddr::from_str(&assignment.ip)?,
&vm.mac_address,
interface,
Some(&format!("VM{}", vm.id)),
)
.await?;
} else {
bail!("No router found to apply static arp entry!")
}
_ => {}
}
// save to db
self.db.insert_vm_ip_assignment(&assignment).await?;
self.db.insert_vm_ip_assignment(assignment).await?;
Ok(())
}
}
@ -180,7 +102,7 @@ impl Provisioner for LNVpsProvisioner {
// tell hosts to download images
let hosts = self.db.list_hosts().await?;
for host in hosts {
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let iso_storage = Self::get_iso_storage(&host.name, &client).await?;
let files = client.list_storage_files(&host.name, &iso_storage).await?;
@ -284,7 +206,7 @@ impl Provisioner for LNVpsProvisioner {
};
const BTC_SATS: f64 = 100_000_000.0;
const INVOICE_EXPIRE: i64 = 3600;
const INVOICE_EXPIRE: u32 = 3600;
let ticker = Ticker::btc_rate(cost_plan.currency.as_str())?;
let rate = if let Some(r) = self.rates.get_rate(ticker).await {
@ -294,27 +216,23 @@ impl Provisioner for LNVpsProvisioner {
};
let cost_btc = cost_plan.amount as f32 / rate;
let cost_msat = (cost_btc as f64 * BTC_SATS) as i64 * 1000;
let cost_msat = (cost_btc as f64 * BTC_SATS) as u64 * 1000;
info!("Creating invoice for {vm_id} for {} sats", cost_msat / 1000);
let mut lnd = self.lnd.clone();
let invoice = lnd
.lightning()
.add_invoice(Invoice {
memo: format!("VM renewal {vm_id} to {new_expire}"),
value_msat: cost_msat,
expiry: INVOICE_EXPIRE,
..Default::default()
let invoice = self
.node
.add_invoice(AddInvoiceRequest {
memo: Some(format!("VM renewal {vm_id} to {new_expire}")),
amount: cost_msat,
expire: Some(INVOICE_EXPIRE),
})
.await?;
let invoice = invoice.into_inner();
let vm_payment = VmPayment {
id: invoice.r_hash.clone(),
id: hex::decode(invoice.payment_hash)?,
vm_id,
created: Utc::now(),
expires: Utc::now().add(Duration::from_secs(INVOICE_EXPIRE as u64)),
amount: cost_msat as u64,
invoice: invoice.payment_request.clone(),
amount: cost_msat,
invoice: invoice.pr,
time_value: (new_expire - vm.expires).num_seconds() as u64,
is_paid: false,
rate,
@ -332,13 +250,13 @@ impl Provisioner for LNVpsProvisioner {
return Ok(existing_ips);
}
let template = self.db.get_vm_template(vm.template_id).await?;
// Use random network provisioner
let prov = NetworkProvisioner::new(
ProvisionerMethod::Random,
self.network_policy.clone(),
self.db.clone(),
);
let template = self.db.get_vm_template(vm.template_id).await?;
let ip = prov.pick_ip_for_region(template.region_id).await?;
let assignment = VmIpAssignment {
id: 0,
@ -352,21 +270,25 @@ impl Provisioner for LNVpsProvisioner {
Ok(vec![assignment])
}
/// Create a vm on the host as configured by the template
async fn spawn_vm(&self, vm_id: u64) -> Result<()> {
if self.read_only {
bail!("Cant spawn VM's in read-only mode");
bail!("Cant spawn VM's in read-only mode")
}
let vm = self.db.get_vm(vm_id).await?;
let template = self.db.get_vm_template(vm.template_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let image = self.db.get_os_image(vm.image_id).await?;
// TODO: remove +100 nonsense (proxmox specific)
let vm_id = 100 + vm.id as i32;
// setup network by allocating some IP space
let ips = self.allocate_ips(vm.id).await?;
// create VM
let config = self.get_vm_config(&vm, &ips).await?;
let config = client.make_vm_config(&self.db, &vm, &ips).await?;
let t_create = client
.create_vm(CreateVm {
node: host.name.clone(),
@ -376,60 +298,31 @@ impl Provisioner for LNVpsProvisioner {
.await?;
client.wait_for_task(&t_create).await?;
// save
// import the disk
// TODO: find a way to avoid using SSH
if let Some(ssh_config) = &self.ssh {
let image = self.db.get_os_image(vm.image_id).await?;
let host_addr: Url = host.ip.parse()?;
let mut ses = SshClient::new()?;
ses.connect(
(host_addr.host().unwrap().to_string(), 22),
&ssh_config.user,
&ssh_config.key,
)
// TODO: pick disk based on available space
// TODO: build external module to manage picking disks
// pick disk
let drives = self.db.list_host_disks(vm.host_id).await?;
let drive = if let Some(d) = drives.iter().find(|d| d.enabled) {
d
} else {
bail!("No host drive found!")
};
// TODO: remove scsi0 terms (proxmox specific)
// import primary disk from image (scsi0)?
client
.import_disk_image(ImportDiskImageRequest {
vm_id,
node: host.name.clone(),
storage: drive.name.clone(),
disk: "scsi0".to_string(),
image: image.filename()?,
is_ssd: matches!(drive.kind, DiskType::SSD),
})
.await?;
let drives = self.db.list_host_disks(vm.host_id).await?;
let drive = if let Some(d) = drives.iter().find(|d| d.enabled) {
d
} else {
bail!("No host drive found!")
};
// Disk import args
let mut disk_args: HashMap<&str, String> = HashMap::new();
disk_args.insert(
"import-from",
format!("/var/lib/vz/template/iso/{}", image.filename()?),
);
// If disk is SSD, enable discard + ssd options
if matches!(drive.kind, DiskType::SSD) {
disk_args.insert("discard", "on".to_string());
disk_args.insert("ssd", "1".to_string());
}
let cmd = format!(
"/usr/sbin/qm set {} --scsi0 {}:0,{}",
vm_id,
&drive.name,
disk_args
.into_iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join(",")
);
let (code, rsp) = ses.execute(cmd.as_str()).await?;
info!("{}", rsp);
if code != 0 {
bail!("Failed to import disk, exit-code {}", code);
}
} else {
bail!("Cannot complete, no method available to import disk, consider configuring ssh")
}
// resize disk
// TODO: remove scsi0 terms (proxmox specific)
// resize disk to match template
let j_resize = client
.resize_disk(ResizeDiskRequest {
node: host.name.clone(),
@ -452,7 +345,7 @@ impl Provisioner for LNVpsProvisioner {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let j_start = client.start_vm(&host.name, vm.id + 100).await?;
client.wait_for_task(&j_start).await?;
Ok(())
@ -462,7 +355,7 @@ impl Provisioner for LNVpsProvisioner {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let j_start = client.shutdown_vm(&host.name, vm.id + 100).await?;
client.wait_for_task(&j_start).await?;
@ -473,7 +366,7 @@ impl Provisioner for LNVpsProvisioner {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let j_start = client.reset_vm(&host.name, vm.id + 100).await?;
client.wait_for_task(&j_start).await?;
@ -517,7 +410,7 @@ impl Provisioner for LNVpsProvisioner {
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let host_vm_id = vm.id + 100;
let term = client.terminal_proxy(&host.name, host_vm_id).await?;
@ -541,7 +434,7 @@ impl Provisioner for LNVpsProvisioner {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let ips = self.db.list_vm_ip_assignments(vm.id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let host_vm_id = vm.id + 100;
let t = client
@ -554,7 +447,7 @@ impl Provisioner for LNVpsProvisioner {
scsi_0: None,
scsi_1: None,
efi_disk_0: None,
..self.get_vm_config(&vm, &ips).await?
..client.make_vm_config(&self.db, &vm, &ips).await?
},
})
.await?;
@ -562,3 +455,50 @@ impl Provisioner for LNVpsProvisioner {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::exchange::DefaultRateCache;
use crate::mocks::{MockDb, MockNode};
use crate::settings::{LndConfig, ProvisionerConfig};
#[tokio::test]
async fn test_basic_provisioner() {
let settings = Settings {
listen: None,
db: "".to_string(),
lnd: LndConfig {
url: "".to_string(),
cert: Default::default(),
macaroon: Default::default(),
},
read_only: false,
provisioner: ProvisionerConfig::Proxmox {
qemu: QemuConfig {
machine: "q35".to_string(),
os_type: "linux26".to_string(),
bridge: "vmbr1".to_string(),
cpu: "kvm64".to_string(),
vlan: None,
kvm: false,
},
ssh: None,
},
network_policy: NetworkPolicy {
access: NetworkAccessPolicy::StaticArp {
interface: "bridge1".to_string(),
},
},
delete_after: 0,
smtp: None,
router: None,
dns: None,
nostr: None,
};
let db = Arc::new(MockDb::default());
let node = Arc::new(MockNode::default());
let rates = Arc::new(DefaultRateCache::default());
let provisioner = LNVpsProvisioner::new(settings, db, node, rates);
}
}