refactor: make services generic for mock testing

This commit is contained in:
2025-02-28 15:51:09 +00:00
parent f105b90849
commit e84e6afd54
17 changed files with 605 additions and 353 deletions

View File

@ -5,7 +5,6 @@ use lnvps_db::VmHostRegion;
use nostr::util::hex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::str::FromStr;
#[derive(Serialize, Deserialize, JsonSchema)]

View File

@ -1,12 +1,10 @@
use anyhow::Error;
use clap::Parser;
use config::{Config, File};
use fedimint_tonic_lnd::connect;
use lnvps::api;
use lnvps::cors::CORS;
use lnvps::exchange::ExchangeRateCache;
use lnvps::exchange::{DefaultRateCache, ExchangeRateService};
use lnvps::invoice::InvoiceHandler;
use lnvps::provisioner::Provisioner;
use lnvps::settings::Settings;
use lnvps::status::VmStateCache;
use lnvps::worker::{WorkJob, Worker};
@ -16,8 +14,10 @@ use nostr::Keys;
use nostr_sdk::Client;
use rocket_okapi::swagger_ui::{make_swagger_ui, SwaggerUIConfig};
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
use lnvps::lightning::get_node;
#[derive(Parser)]
#[clap(about, version, author)]
@ -38,16 +38,8 @@ async fn main() -> Result<(), Error> {
.build()?
.try_deserialize()?;
let db = LNVpsDbMysql::new(&settings.db).await?;
let db = Arc::new(LNVpsDbMysql::new(&settings.db).await?);
db.migrate().await?;
let exchange = ExchangeRateCache::new();
let lnd = connect(
settings.lnd.url.clone(),
settings.lnd.cert.clone(),
settings.lnd.macaroon.clone(),
)
.await?;
#[cfg(debug_assertions)]
{
let setup_script = include_str!("../../dev_setup.sql");
@ -64,15 +56,18 @@ async fn main() -> Result<(), Error> {
} else {
None
};
let router = settings.get_router()?;
let exchange: Arc<dyn ExchangeRateService> = Arc::new(DefaultRateCache::default());
let node = get_node(&settings).await?;
let status = VmStateCache::new();
let worker_provisioner =
settings.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
worker_provisioner.init().await?;
let provisioner =
settings.get_provisioner(db.clone(), node.clone(), exchange.clone());
provisioner.init().await?;
let mut worker = Worker::new(
db.clone(),
worker_provisioner,
provisioner.clone(),
&settings,
status.clone(),
nostr_client.clone(),
@ -95,7 +90,7 @@ async fn main() -> Result<(), Error> {
}
}
});
let mut handler = InvoiceHandler::new(lnd.clone(), db.clone(), sender.clone());
let mut handler = InvoiceHandler::new(node.clone(), db.clone(), sender.clone());
tokio::spawn(async move {
loop {
if let Err(e) = handler.listen().await {
@ -130,12 +125,6 @@ async fn main() -> Result<(), Error> {
}
});
let router = settings.get_router()?;
let provisioner = settings.get_provisioner(db.clone(), router, lnd.clone(), exchange.clone());
let db: Box<dyn LNVpsDb> = Box::new(db.clone());
let pv: Box<dyn Provisioner> = Box::new(provisioner);
let mut config = rocket::Config::default();
let ip: SocketAddr = match &settings.listen {
Some(i) => i.parse()?,
@ -147,7 +136,7 @@ async fn main() -> Result<(), Error> {
if let Err(e) = rocket::Rocket::custom(config)
.attach(CORS)
.manage(db)
.manage(pv)
.manage(provisioner)
.manage(status)
.manage(exchange)
.manage(sender)

View File

@ -1,4 +1,5 @@
use anyhow::{Error, Result};
use lnvps_db::async_trait;
use log::info;
use rocket::serde::Deserialize;
use std::collections::HashMap;
@ -56,33 +57,21 @@ impl Display for Ticker {
#[derive(Debug, PartialEq)]
pub struct TickerRate(pub Ticker, pub f32);
#[derive(Clone)]
pub struct ExchangeRateCache {
#[async_trait]
pub trait ExchangeRateService: Send + Sync {
async fn fetch_rates(&self) -> Result<Vec<TickerRate>>;
async fn set_rate(&self, ticker: Ticker, amount: f32);
async fn get_rate(&self, ticker: Ticker) -> Option<f32>;
}
#[derive(Clone, Default)]
pub struct DefaultRateCache {
cache: Arc<RwLock<HashMap<Ticker, f32>>>,
}
#[derive(Deserialize)]
struct MempoolRates {
#[serde(rename = "USD")]
pub usd: Option<f32>,
#[serde(rename = "EUR")]
pub eur: Option<f32>,
}
impl Default for ExchangeRateCache {
fn default() -> Self {
Self::new()
}
}
impl ExchangeRateCache {
pub fn new() -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn fetch_rates(&self) -> Result<Vec<TickerRate>> {
#[async_trait]
impl ExchangeRateService for DefaultRateCache {
async fn fetch_rates(&self) -> Result<Vec<TickerRate>> {
let rsp = reqwest::get("https://mempool.space/api/v1/prices")
.await?
.text()
@ -100,14 +89,22 @@ impl ExchangeRateCache {
Ok(ret)
}
pub async fn set_rate(&self, ticker: Ticker, amount: f32) {
async fn set_rate(&self, ticker: Ticker, amount: f32) {
let mut cache = self.cache.write().await;
info!("{}: {}", &ticker, amount);
cache.insert(ticker, amount);
}
pub async fn get_rate(&self, ticker: Ticker) -> Option<f32> {
async fn get_rate(&self, ticker: Ticker) -> Option<f32> {
let cache = self.cache.read().await;
cache.get(&ticker).cloned()
}
}
#[derive(Deserialize)]
struct MempoolRates {
#[serde(rename = "USD")]
pub usd: Option<f32>,
#[serde(rename = "EUR")]
pub eur: Option<f32>,
}

View File

@ -1,12 +1,22 @@
use crate::host::proxmox::ProxmoxClient;
use anyhow::Result;
use lnvps_db::{VmHost, VmHostKind};
use crate::settings::ProvisionerConfig;
use anyhow::{bail, Result};
use lnvps_db::{async_trait, VmHost, VmHostKind};
pub mod proxmox;
pub trait VmHostClient {}
pub fn get_host_client(host: &VmHost) -> Result<ProxmoxClient> {
Ok(match host.kind {
VmHostKind::Proxmox => ProxmoxClient::new(host.ip.parse()?).with_api_token(&host.api_token),
/// Generic type for creating VM's
#[async_trait]
pub trait VmHostClient {
}
pub fn get_host_client(host: &VmHost, cfg: &ProvisionerConfig) -> Result<ProxmoxClient> {
Ok(match (host.kind.clone(), &cfg) {
(VmHostKind::Proxmox, ProvisionerConfig::Proxmox { qemu, ssh, .. }) => {
ProxmoxClient::new(host.ip.parse()?, qemu.clone(), ssh.clone())
.with_api_token(&host.api_token)
}
_ => bail!("Unsupported host type"),
})
}

View File

@ -1,10 +1,17 @@
use crate::settings::{QemuConfig, SshConfig};
use crate::ssh_client::SshClient;
use anyhow::{anyhow, bail, Result};
use log::debug;
use ipnetwork::IpNetwork;
use lnvps_db::{IpRange, LNVpsDb, Vm, VmIpAssignment};
use log::{debug, info};
use nostr_sdk::async_utility::futures_util::future::join_all;
use reqwest::{ClientBuilder, Method, Url};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::sleep;
@ -15,10 +22,12 @@ pub struct ProxmoxClient {
base: Url,
token: String,
client: reqwest::Client,
config: QemuConfig,
ssh: Option<SshConfig>,
}
impl ProxmoxClient {
pub fn new(base: Url) -> Self {
pub fn new(base: Url, config: QemuConfig, ssh: Option<SshConfig>) -> Self {
let client = ClientBuilder::new()
.danger_accept_invalid_certs(true)
.build()
@ -28,9 +37,84 @@ impl ProxmoxClient {
base,
token: String::new(),
client,
config,
ssh,
}
}
/// Create [VmConfig] for a given VM and list of IPs
pub async fn make_vm_config(
&self,
db: &Arc<dyn LNVpsDb>,
vm: &Vm,
ips: &Vec<VmIpAssignment>,
) -> Result<VmConfig> {
let ssh_key = db.get_user_ssh_key(vm.ssh_key_id).await?;
let ip_range_ids: HashSet<u64> = ips.iter().map(|i| i.ip_range_id).collect();
let ip_ranges: Vec<_> = ip_range_ids.iter().map(|i| db.get_ip_range(*i)).collect();
let ip_ranges: HashMap<u64, IpRange> = join_all(ip_ranges)
.await
.into_iter()
.filter_map(Result::ok)
.map(|i| (i.id, i))
.collect();
let mut ip_config = ips
.iter()
.map_while(|ip| {
if let Ok(net) = ip.ip.parse::<IpNetwork>() {
Some(match net {
IpNetwork::V4(addr) => {
let range = ip_ranges.get(&ip.ip_range_id)?;
format!("ip={},gw={}", addr, range.gateway)
}
IpNetwork::V6(addr) => format!("ip6={}", addr),
})
} else {
None
}
})
.collect::<Vec<_>>();
ip_config.push("ip6=auto".to_string());
let mut net = vec![
format!("virtio={}", vm.mac_address),
format!("bridge={}", self.config.bridge),
];
if let Some(t) = self.config.vlan {
net.push(format!("tag={}", t));
}
let drives = db.list_host_disks(vm.host_id).await?;
let drive = if let Some(d) = drives.iter().find(|d| d.enabled) {
d
} else {
bail!("No host drive found!")
};
let template = db.get_vm_template(vm.template_id).await?;
Ok(VmConfig {
cpu: Some(self.config.cpu.clone()),
kvm: Some(self.config.kvm),
ip_config: Some(ip_config.join(",")),
machine: Some(self.config.machine.clone()),
net: Some(net.join(",")),
os_type: Some(self.config.os_type.clone()),
on_boot: Some(true),
bios: Some(VmBios::OVMF),
boot: Some("order=scsi0".to_string()),
cores: Some(template.cpu as i32),
memory: Some((template.memory / 1024 / 1024).to_string()),
scsi_hw: Some("virtio-scsi-pci".to_string()),
serial_0: Some("socket".to_string()),
scsi_1: Some(format!("{}:cloudinit", &drive.name)),
ssh_keys: Some(urlencoding::encode(&ssh_key.key_data).to_string()),
efi_disk_0: Some(format!("{}:0,efitype=4m", &drive.name)),
..Default::default()
})
}
pub fn with_api_token(mut self, token: &str) -> Self {
// PVEAPIToken=USER@REALM!TOKENID=UUID
self.token = token.to_string();
@ -185,6 +269,54 @@ impl ProxmoxClient {
})
}
pub async fn import_disk_image(&self, req: ImportDiskImageRequest) -> Result<()> {
// import the disk
// TODO: find a way to avoid using SSH
if let Some(ssh_config) = &self.ssh {
let mut ses = SshClient::new()?;
ses.connect(
(self.base.host().unwrap().to_string(), 22),
&ssh_config.user,
&ssh_config.key,
)
.await?;
// Disk import args
let mut disk_args: HashMap<&str, String> = HashMap::new();
disk_args.insert(
"import-from",
format!("/var/lib/vz/template/iso/{}", req.image),
);
// If disk is SSD, enable discard + ssd options
if req.is_ssd {
disk_args.insert("discard", "on".to_string());
disk_args.insert("ssd", "1".to_string());
}
let cmd = format!(
"/usr/sbin/qm set {} --{} {}:0,{}",
req.vm_id,
&req.disk,
&req.storage,
disk_args
.into_iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join(",")
);
let (code, rsp) = ses.execute(cmd.as_str()).await?;
info!("{}", rsp);
if code != 0 {
bail!("Failed to import disk, exit-code {}, {}", code, rsp);
}
Ok(())
} else {
bail!("Cannot complete, no method available to import disk, consider configuring ssh")
}
}
/// Resize a disk on a VM
pub async fn resize_disk(&self, req: ResizeDiskRequest) -> Result<TaskId> {
let rsp: ResponseBase<String> = self
@ -571,6 +703,22 @@ pub struct ResizeDiskRequest {
pub size: String,
}
#[derive(Debug, Deserialize, Serialize, Default)]
pub struct ImportDiskImageRequest {
/// VM id
pub vm_id: i32,
/// Node name
pub node: String,
/// Storage pool to import disk to
pub storage: String,
/// Disk name (scsi0 etc)
pub disk: String,
/// Image filename on disk inside the disk storage dir
pub image: String,
/// If the disk is an SSD and discard should be enabled
pub is_ssd: bool,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum VmBios {

View File

@ -1,27 +1,26 @@
use crate::lightning::{InvoiceUpdate, LightningNode};
use crate::worker::WorkJob;
use anyhow::Result;
use fedimint_tonic_lnd::lnrpc::invoice::InvoiceState;
use fedimint_tonic_lnd::lnrpc::InvoiceSubscription;
use fedimint_tonic_lnd::Client;
use lnvps_db::LNVpsDb;
use log::{error, info};
use log::{error, info, warn};
use nostr::util::hex;
use rocket::futures::StreamExt;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
pub struct InvoiceHandler {
lnd: Client,
db: Box<dyn LNVpsDb>,
node: Arc<dyn LightningNode>,
db: Arc<dyn LNVpsDb>,
tx: UnboundedSender<WorkJob>,
}
impl InvoiceHandler {
pub fn new<D: LNVpsDb + 'static>(lnd: Client, db: D, tx: UnboundedSender<WorkJob>) -> Self {
Self {
lnd,
tx,
db: Box::new(db),
}
pub fn new(
node: Arc<dyn LightningNode>,
db: Arc<dyn LNVpsDb>,
tx: UnboundedSender<WorkJob>,
) -> Self {
Self { node, tx, db }
}
async fn mark_paid(&self, settle_index: u64, id: &Vec<u8>) -> Result<()> {
@ -36,33 +35,28 @@ impl InvoiceHandler {
}
pub async fn listen(&mut self) -> Result<()> {
let from_settle_index = if let Some(p) = self.db.last_paid_invoice().await? {
p.settle_index.unwrap_or(0)
} else {
0
};
info!("Listening for invoices from {from_settle_index}");
let from_ph = self.db.last_paid_invoice().await?.map(|i| i.id.clone());
info!(
"Listening for invoices from {}",
from_ph
.as_ref()
.map(hex::encode)
.unwrap_or("NOW".to_string())
);
let handler = self
.lnd
.lightning()
.subscribe_invoices(InvoiceSubscription {
add_index: 0,
settle_index: from_settle_index,
})
.await?;
let mut stream = handler.into_inner();
while let Some(msg) = stream.next().await {
let mut handler = self.node.subscribe_invoices(from_ph).await?;
while let Some(msg) = handler.next().await {
match msg {
Ok(i) => {
if i.state == InvoiceState::Settled as i32 {
if let Err(e) = self.mark_paid(i.settle_index, &i.r_hash).await {
error!("{}", e);
}
InvoiceUpdate::Settled {
payment_hash,
settle_index,
} => {
let r_hash = hex::decode(payment_hash)?;
if let Err(e) = self.mark_paid(settle_index, &r_hash).await {
error!("{}", e);
}
}
Err(e) => error!("{}", e),
v => warn!("Unknown invoice update: {:?}", v),
}
}
Ok(())

View File

@ -10,6 +10,7 @@ pub mod settings;
pub mod ssh_client;
pub mod status;
pub mod worker;
pub mod lightning;
#[cfg(test)]
pub mod mocks;

96
src/lightning/lnd.rs Normal file
View File

@ -0,0 +1,96 @@
use crate::lightning::{AddInvoiceRequest, AddInvoiceResult, InvoiceUpdate, LightningNode};
use crate::settings::LndConfig;
use anyhow::Result;
use fedimint_tonic_lnd::invoicesrpc::lookup_invoice_msg::InvoiceRef;
use fedimint_tonic_lnd::invoicesrpc::LookupInvoiceMsg;
use fedimint_tonic_lnd::lnrpc::invoice::InvoiceState;
use fedimint_tonic_lnd::lnrpc::{Invoice, InvoiceSubscription};
use fedimint_tonic_lnd::{connect, Client};
use futures::StreamExt;
use lnvps_db::async_trait;
use nostr_sdk::async_utility::futures_util::Stream;
use std::pin::Pin;
pub struct LndNode {
client: Client,
}
impl LndNode {
pub async fn new(settings: &LndConfig) -> Result<Self> {
let lnd = connect(
settings.url.clone(),
settings.cert.clone(),
settings.macaroon.clone(),
)
.await?;
Ok(Self { client: lnd })
}
}
#[async_trait]
impl LightningNode for LndNode {
async fn add_invoice(&self, req: AddInvoiceRequest) -> Result<AddInvoiceResult> {
let mut client = self.client.clone();
let ln = client.lightning();
let res = ln
.add_invoice(Invoice {
memo: req.memo.unwrap_or_default(),
value_msat: req.amount as i64,
expiry: req.expire.unwrap_or(3600) as i64,
..Default::default()
})
.await?;
let inner = res.into_inner();
Ok(AddInvoiceResult {
pr: inner.payment_request,
payment_hash: hex::encode(inner.r_hash),
})
}
async fn subscribe_invoices(
&self,
from_payment_hash: Option<Vec<u8>>,
) -> Result<Pin<Box<dyn Stream<Item = InvoiceUpdate> + Send>>> {
let mut client = self.client.clone();
let from_settle_index = if let Some(ph) = from_payment_hash {
if let Ok(inv) = client
.invoices()
.lookup_invoice_v2(LookupInvoiceMsg {
lookup_modifier: 0,
invoice_ref: Some(InvoiceRef::PaymentHash(ph)),
})
.await
{
inv.into_inner().settle_index
} else {
0
}
} else {
0
};
let stream = client
.lightning()
.subscribe_invoices(InvoiceSubscription {
add_index: 0,
settle_index: from_settle_index,
})
.await?;
let stream = stream.into_inner();
Ok(Box::pin(stream.map(|i| match i {
Ok(m) => {
if m.state == InvoiceState::Settled as i32 {
InvoiceUpdate::Settled {
settle_index: m.settle_index,
payment_hash: hex::encode(m.r_hash),
}
} else {
InvoiceUpdate::Unknown
}
}
Err(e) => InvoiceUpdate::Error(e.to_string()),
})))
}
}

47
src/lightning/mod.rs Normal file
View File

@ -0,0 +1,47 @@
use crate::lightning::lnd::LndNode;
use crate::settings::Settings;
use anyhow::Result;
use futures::Stream;
use lnvps_db::async_trait;
use std::pin::Pin;
use std::sync::Arc;
mod lnd;
/// Generic lightning node for creating payments
#[async_trait]
pub trait LightningNode: Send + Sync {
async fn add_invoice(&self, req: AddInvoiceRequest) -> Result<AddInvoiceResult>;
async fn subscribe_invoices(
&self,
from_payment_hash: Option<Vec<u8>>,
) -> Result<Pin<Box<dyn Stream<Item = InvoiceUpdate> + Send>>>;
}
#[derive(Debug, Clone)]
pub struct AddInvoiceRequest {
pub amount: u64,
pub memo: Option<String>,
pub expire: Option<u32>,
}
#[derive(Debug, Clone)]
pub struct AddInvoiceResult {
pub pr: String,
pub payment_hash: String,
}
#[derive(Debug, Clone)]
pub enum InvoiceUpdate {
/// Internal impl created an update which we don't support or care about
Unknown,
Error(String),
Settled {
payment_hash: String,
settle_index: u64,
},
}
pub async fn get_node(settings: &Settings) -> Result<Arc<dyn LightningNode>> {
Ok(Arc::new(LndNode::new(&settings.lnd).await?))
}

View File

@ -1,6 +1,6 @@
use crate::router::{ArpEntry, Router};
use crate::settings::NetworkPolicy;
use anyhow::anyhow;
use anyhow::{anyhow, bail};
use chrono::Utc;
use lnvps_db::{
async_trait, IpRange, LNVpsDb, User, UserSshKey, Vm, VmCostPlan, VmHost, VmHostDisk,
@ -10,6 +10,8 @@ use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::exchange::{ExchangeRateService, Ticker, TickerRate};
use crate::lightning::LightningNode;
#[derive(Debug, Clone)]
pub struct MockDb {
@ -313,12 +315,14 @@ impl LNVpsDb for MockDb {
struct MockRouter {
pub policy: NetworkPolicy,
pub arp: Arc<Mutex<HashMap<String, ArpEntry>>>,
}
#[async_trait]
impl Router for MockRouter {
async fn list_arp_entry(&self) -> anyhow::Result<Vec<ArpEntry>> {
todo!()
let arp = self.arp.lock().await;
Ok(arp.values().cloned().collect())
}
async fn add_arp_entry(
@ -328,10 +332,34 @@ impl Router for MockRouter {
interface: &str,
comment: Option<&str>,
) -> anyhow::Result<()> {
todo!()
let mut arp = self.arp.lock().await;
if arp.iter().any(|(k, v)| v.address == ip.to_string()) {
bail!("Address is already in use");
}
arp.insert(
mac.to_string(),
ArpEntry {
id: Some(mac.to_string()),
address: ip.to_string(),
mac_address: None,
interface: interface.to_string(),
comment: comment.map(|s| s.to_string()),
},
);
Ok(())
}
async fn remove_arp_entry(&self, id: &str) -> anyhow::Result<()> {
todo!()
let mut arp = self.arp.lock().await;
arp.remove(id);
Ok(())
}
}
#[derive(Clone, Debug, Default)]
pub struct MockNode {
}
#[async_trait]
impl LightningNode for MockNode {}

View File

@ -1,28 +1,23 @@
use crate::exchange::{ExchangeRateCache, Ticker};
use crate::exchange::{ExchangeRateService, Ticker};
use crate::host::get_host_client;
use crate::host::proxmox::{
ConfigureVm, CreateVm, DownloadUrlRequest, ProxmoxClient, ResizeDiskRequest, StorageContent,
VmBios, VmConfig,
ConfigureVm, CreateVm, DownloadUrlRequest, ImportDiskImageRequest, ProxmoxClient,
ResizeDiskRequest, StorageContent, VmConfig,
};
use crate::lightning::{AddInvoiceRequest, LightningNode};
use crate::provisioner::{NetworkProvisioner, Provisioner, ProvisionerMethod};
use crate::router::Router;
use crate::settings::{NetworkAccessPolicy, NetworkPolicy, QemuConfig, SshConfig};
use crate::ssh_client::SshClient;
use crate::settings::{
NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings,
};
use anyhow::{bail, Result};
use chrono::{Days, Months, Utc};
use fedimint_tonic_lnd::lnrpc::Invoice;
use fedimint_tonic_lnd::tonic::async_trait;
use fedimint_tonic_lnd::Client;
use ipnetwork::IpNetwork;
use lnvps_db::{DiskType, IpRange, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment};
use lnvps_db::{DiskType, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment};
use log::{debug, info, warn};
use nostr::util::hex;
use rand::random;
use rand::seq::IteratorRandom;
use reqwest::Url;
use rocket::futures::future::join_all;
use rocket::futures::{SinkExt, StreamExt};
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::ops::Add;
use std::str::FromStr;
@ -32,37 +27,38 @@ use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
/// Main provisioner class for LNVPS
///
/// Does all the hard work and logic for creating / expiring VM's
pub struct LNVpsProvisioner {
db: Arc<dyn LNVpsDb>,
router: Option<Box<dyn Router>>,
lnd: Client,
rates: ExchangeRateCache,
read_only: bool,
config: QemuConfig,
db: Arc<dyn LNVpsDb>,
node: Arc<dyn LightningNode>,
rates: Arc<dyn ExchangeRateService>,
router: Option<Arc<dyn Router>>,
network_policy: NetworkPolicy,
ssh: Option<SshConfig>,
provisioner_config: ProvisionerConfig,
}
impl LNVpsProvisioner {
pub fn new(
read_only: bool,
config: QemuConfig,
network_policy: NetworkPolicy,
ssh: Option<SshConfig>,
db: impl LNVpsDb + 'static,
router: Option<impl Router + 'static>,
lnd: Client,
rates: ExchangeRateCache,
settings: Settings,
db: Arc<dyn LNVpsDb>,
node: Arc<dyn LightningNode>,
rates: Arc<dyn ExchangeRateService>,
) -> Self {
Self {
db: Arc::new(db),
router: router.map(|r| Box::new(r) as Box<dyn Router>),
lnd,
network_policy,
db,
node,
rates,
config,
read_only,
ssh,
router: settings
.get_router()
.expect("router config"),
network_policy: settings.network_policy,
provisioner_config: settings.provisioner,
read_only: settings.read_only,
}
}
@ -78,98 +74,24 @@ impl LNVpsProvisioner {
}
}
async fn get_vm_config(&self, vm: &Vm, ips: &Vec<VmIpAssignment>) -> Result<VmConfig> {
let ssh_key = self.db.get_user_ssh_key(vm.ssh_key_id).await?;
let ip_range_ids: HashSet<u64> = ips.iter().map(|i| i.ip_range_id).collect();
let ip_ranges: Vec<_> = ip_range_ids
.iter()
.map(|i| self.db.get_ip_range(*i))
.collect();
let ip_ranges: HashMap<u64, IpRange> = join_all(ip_ranges)
.await
.into_iter()
.filter_map(Result::ok)
.map(|i| (i.id, i))
.collect();
let mut ip_config = ips
.iter()
.map_while(|ip| {
if let Ok(net) = ip.ip.parse::<IpNetwork>() {
Some(match net {
IpNetwork::V4(addr) => {
let range = ip_ranges.get(&ip.ip_range_id)?;
format!("ip={},gw={}", addr, range.gateway)
}
IpNetwork::V6(addr) => format!("ip6={}", addr),
})
} else {
None
}
})
.collect::<Vec<_>>();
ip_config.push("ip6=auto".to_string());
let mut net = vec![
format!("virtio={}", vm.mac_address),
format!("bridge={}", self.config.bridge),
];
if let Some(t) = self.config.vlan {
net.push(format!("tag={}", t));
}
let drives = self.db.list_host_disks(vm.host_id).await?;
let drive = if let Some(d) = drives.iter().find(|d| d.enabled) {
d
} else {
bail!("No host drive found!")
};
let template = self.db.get_vm_template(vm.template_id).await?;
Ok(VmConfig {
cpu: Some(self.config.cpu.clone()),
kvm: Some(self.config.kvm),
ip_config: Some(ip_config.join(",")),
machine: Some(self.config.machine.clone()),
net: Some(net.join(",")),
os_type: Some(self.config.os_type.clone()),
on_boot: Some(true),
bios: Some(VmBios::OVMF),
boot: Some("order=scsi0".to_string()),
cores: Some(template.cpu as i32),
memory: Some((template.memory / 1024 / 1024).to_string()),
scsi_hw: Some("virtio-scsi-pci".to_string()),
serial_0: Some("socket".to_string()),
scsi_1: Some(format!("{}:cloudinit", &drive.name)),
ssh_keys: Some(urlencoding::encode(&ssh_key.key_data).to_string()),
efi_disk_0: Some(format!("{}:0,efitype=4m", &drive.name)),
..Default::default()
})
}
async fn save_ip_assignment(&self, vm: &Vm, assignment: &VmIpAssignment) -> Result<()> {
// apply network policy
match &self.network_policy.access {
NetworkAccessPolicy::StaticArp { interface } => {
if let Some(r) = self.router.as_ref() {
r.add_arp_entry(
IpAddr::from_str(&assignment.ip)?,
&vm.mac_address,
interface,
Some(&format!("VM{}", vm.id)),
)
.await?;
} else {
bail!("No router found to apply static arp entry!")
}
if let NetworkAccessPolicy::StaticArp { interface } = &self.network_policy.access {
if let Some(r) = self.router.as_ref() {
r.add_arp_entry(
IpAddr::from_str(&assignment.ip)?,
&vm.mac_address,
interface,
Some(&format!("VM{}", vm.id)),
)
.await?;
} else {
bail!("No router found to apply static arp entry!")
}
_ => {}
}
// save to db
self.db.insert_vm_ip_assignment(&assignment).await?;
self.db.insert_vm_ip_assignment(assignment).await?;
Ok(())
}
}
@ -180,7 +102,7 @@ impl Provisioner for LNVpsProvisioner {
// tell hosts to download images
let hosts = self.db.list_hosts().await?;
for host in hosts {
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let iso_storage = Self::get_iso_storage(&host.name, &client).await?;
let files = client.list_storage_files(&host.name, &iso_storage).await?;
@ -284,7 +206,7 @@ impl Provisioner for LNVpsProvisioner {
};
const BTC_SATS: f64 = 100_000_000.0;
const INVOICE_EXPIRE: i64 = 3600;
const INVOICE_EXPIRE: u32 = 3600;
let ticker = Ticker::btc_rate(cost_plan.currency.as_str())?;
let rate = if let Some(r) = self.rates.get_rate(ticker).await {
@ -294,27 +216,23 @@ impl Provisioner for LNVpsProvisioner {
};
let cost_btc = cost_plan.amount as f32 / rate;
let cost_msat = (cost_btc as f64 * BTC_SATS) as i64 * 1000;
let cost_msat = (cost_btc as f64 * BTC_SATS) as u64 * 1000;
info!("Creating invoice for {vm_id} for {} sats", cost_msat / 1000);
let mut lnd = self.lnd.clone();
let invoice = lnd
.lightning()
.add_invoice(Invoice {
memo: format!("VM renewal {vm_id} to {new_expire}"),
value_msat: cost_msat,
expiry: INVOICE_EXPIRE,
..Default::default()
let invoice = self
.node
.add_invoice(AddInvoiceRequest {
memo: Some(format!("VM renewal {vm_id} to {new_expire}")),
amount: cost_msat,
expire: Some(INVOICE_EXPIRE),
})
.await?;
let invoice = invoice.into_inner();
let vm_payment = VmPayment {
id: invoice.r_hash.clone(),
id: hex::decode(invoice.payment_hash)?,
vm_id,
created: Utc::now(),
expires: Utc::now().add(Duration::from_secs(INVOICE_EXPIRE as u64)),
amount: cost_msat as u64,
invoice: invoice.payment_request.clone(),
amount: cost_msat,
invoice: invoice.pr,
time_value: (new_expire - vm.expires).num_seconds() as u64,
is_paid: false,
rate,
@ -332,13 +250,13 @@ impl Provisioner for LNVpsProvisioner {
return Ok(existing_ips);
}
let template = self.db.get_vm_template(vm.template_id).await?;
// Use random network provisioner
let prov = NetworkProvisioner::new(
ProvisionerMethod::Random,
self.network_policy.clone(),
self.db.clone(),
);
let template = self.db.get_vm_template(vm.template_id).await?;
let ip = prov.pick_ip_for_region(template.region_id).await?;
let assignment = VmIpAssignment {
id: 0,
@ -352,21 +270,25 @@ impl Provisioner for LNVpsProvisioner {
Ok(vec![assignment])
}
/// Create a vm on the host as configured by the template
async fn spawn_vm(&self, vm_id: u64) -> Result<()> {
if self.read_only {
bail!("Cant spawn VM's in read-only mode");
bail!("Cant spawn VM's in read-only mode")
}
let vm = self.db.get_vm(vm_id).await?;
let template = self.db.get_vm_template(vm.template_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let image = self.db.get_os_image(vm.image_id).await?;
// TODO: remove +100 nonsense (proxmox specific)
let vm_id = 100 + vm.id as i32;
// setup network by allocating some IP space
let ips = self.allocate_ips(vm.id).await?;
// create VM
let config = self.get_vm_config(&vm, &ips).await?;
let config = client.make_vm_config(&self.db, &vm, &ips).await?;
let t_create = client
.create_vm(CreateVm {
node: host.name.clone(),
@ -376,60 +298,31 @@ impl Provisioner for LNVpsProvisioner {
.await?;
client.wait_for_task(&t_create).await?;
// save
// import the disk
// TODO: find a way to avoid using SSH
if let Some(ssh_config) = &self.ssh {
let image = self.db.get_os_image(vm.image_id).await?;
let host_addr: Url = host.ip.parse()?;
let mut ses = SshClient::new()?;
ses.connect(
(host_addr.host().unwrap().to_string(), 22),
&ssh_config.user,
&ssh_config.key,
)
// TODO: pick disk based on available space
// TODO: build external module to manage picking disks
// pick disk
let drives = self.db.list_host_disks(vm.host_id).await?;
let drive = if let Some(d) = drives.iter().find(|d| d.enabled) {
d
} else {
bail!("No host drive found!")
};
// TODO: remove scsi0 terms (proxmox specific)
// import primary disk from image (scsi0)?
client
.import_disk_image(ImportDiskImageRequest {
vm_id,
node: host.name.clone(),
storage: drive.name.clone(),
disk: "scsi0".to_string(),
image: image.filename()?,
is_ssd: matches!(drive.kind, DiskType::SSD),
})
.await?;
let drives = self.db.list_host_disks(vm.host_id).await?;
let drive = if let Some(d) = drives.iter().find(|d| d.enabled) {
d
} else {
bail!("No host drive found!")
};
// Disk import args
let mut disk_args: HashMap<&str, String> = HashMap::new();
disk_args.insert(
"import-from",
format!("/var/lib/vz/template/iso/{}", image.filename()?),
);
// If disk is SSD, enable discard + ssd options
if matches!(drive.kind, DiskType::SSD) {
disk_args.insert("discard", "on".to_string());
disk_args.insert("ssd", "1".to_string());
}
let cmd = format!(
"/usr/sbin/qm set {} --scsi0 {}:0,{}",
vm_id,
&drive.name,
disk_args
.into_iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join(",")
);
let (code, rsp) = ses.execute(cmd.as_str()).await?;
info!("{}", rsp);
if code != 0 {
bail!("Failed to import disk, exit-code {}", code);
}
} else {
bail!("Cannot complete, no method available to import disk, consider configuring ssh")
}
// resize disk
// TODO: remove scsi0 terms (proxmox specific)
// resize disk to match template
let j_resize = client
.resize_disk(ResizeDiskRequest {
node: host.name.clone(),
@ -452,7 +345,7 @@ impl Provisioner for LNVpsProvisioner {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let j_start = client.start_vm(&host.name, vm.id + 100).await?;
client.wait_for_task(&j_start).await?;
Ok(())
@ -462,7 +355,7 @@ impl Provisioner for LNVpsProvisioner {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let j_start = client.shutdown_vm(&host.name, vm.id + 100).await?;
client.wait_for_task(&j_start).await?;
@ -473,7 +366,7 @@ impl Provisioner for LNVpsProvisioner {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let j_start = client.reset_vm(&host.name, vm.id + 100).await?;
client.wait_for_task(&j_start).await?;
@ -517,7 +410,7 @@ impl Provisioner for LNVpsProvisioner {
) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let host_vm_id = vm.id + 100;
let term = client.terminal_proxy(&host.name, host_vm_id).await?;
@ -541,7 +434,7 @@ impl Provisioner for LNVpsProvisioner {
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let ips = self.db.list_vm_ip_assignments(vm.id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.provisioner_config)?;
let host_vm_id = vm.id + 100;
let t = client
@ -554,7 +447,7 @@ impl Provisioner for LNVpsProvisioner {
scsi_0: None,
scsi_1: None,
efi_disk_0: None,
..self.get_vm_config(&vm, &ips).await?
..client.make_vm_config(&self.db, &vm, &ips).await?
},
})
.await?;
@ -562,3 +455,50 @@ impl Provisioner for LNVpsProvisioner {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::exchange::DefaultRateCache;
use crate::mocks::{MockDb, MockNode};
use crate::settings::{LndConfig, ProvisionerConfig};
#[tokio::test]
async fn test_basic_provisioner() {
let settings = Settings {
listen: None,
db: "".to_string(),
lnd: LndConfig {
url: "".to_string(),
cert: Default::default(),
macaroon: Default::default(),
},
read_only: false,
provisioner: ProvisionerConfig::Proxmox {
qemu: QemuConfig {
machine: "q35".to_string(),
os_type: "linux26".to_string(),
bridge: "vmbr1".to_string(),
cpu: "kvm64".to_string(),
vlan: None,
kvm: false,
},
ssh: None,
},
network_policy: NetworkPolicy {
access: NetworkAccessPolicy::StaticArp {
interface: "bridge1".to_string(),
},
},
delete_after: 0,
smtp: None,
router: None,
dns: None,
nostr: None,
};
let db = Arc::new(MockDb::default());
let node = Arc::new(MockNode::default());
let rates = Arc::new(DefaultRateCache::default());
let provisioner = LNVpsProvisioner::new(settings, db, node, rates);
}
}

View File

@ -20,18 +20,17 @@ pub struct AvailableIp {
pub region_id: u64,
}
/// Handles picking available IPs
#[derive(Clone)]
pub struct NetworkProvisioner {
method: ProvisionerMethod,
settings: NetworkPolicy,
db: Arc<dyn LNVpsDb>,
}
impl NetworkProvisioner {
pub fn new(method: ProvisionerMethod, settings: NetworkPolicy, db: Arc<dyn LNVpsDb>) -> Self {
pub fn new(method: ProvisionerMethod, db: Arc<dyn LNVpsDb>) -> Self {
Self {
method,
settings,
db,
}
}
@ -97,9 +96,6 @@ mod tests {
let db: Arc<dyn LNVpsDb> = Arc::new(MockDb::default());
let mgr = NetworkProvisioner::new(
ProvisionerMethod::Sequential,
NetworkPolicy {
access: NetworkAccessPolicy::Auto,
},
db.clone(),
);
@ -130,9 +126,6 @@ mod tests {
let db: Arc<dyn LNVpsDb> = Arc::new(MockDb::default());
let mgr = NetworkProvisioner::new(
ProvisionerMethod::Random,
NetworkPolicy {
access: NetworkAccessPolicy::Auto,
},
db,
);

View File

@ -1,12 +1,13 @@
use crate::exchange::ExchangeRateCache;
use crate::exchange::ExchangeRateService;
use crate::lightning::LightningNode;
use crate::provisioner::LNVpsProvisioner;
use crate::provisioner::Provisioner;
use crate::router::{MikrotikRouter, Router};
use anyhow::{bail, Result};
use fedimint_tonic_lnd::Client;
use lnvps_db::LNVpsDb;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
@ -15,6 +16,9 @@ pub struct Settings {
pub db: String,
pub lnd: LndConfig,
/// Readonly mode, don't spawn any VM's
pub read_only: bool,
/// Provisioning profiles
pub provisioner: ProvisionerConfig,
@ -60,6 +64,7 @@ pub enum RouterConfig {
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum DnsServerConfig {
#[serde(rename_all = "kebab-case")]
Cloudflare { api: ApiConfig, zone_id: String },
}
@ -124,8 +129,6 @@ pub struct SmtpConfig {
pub enum ProvisionerConfig {
#[serde(rename_all = "kebab-case")]
Proxmox {
/// Readonly mode, don't spawn any VM's
read_only: bool,
/// Generic VM configuration
qemu: QemuConfig,
/// SSH config for issuing commands via CLI
@ -161,34 +164,19 @@ pub struct QemuConfig {
impl Settings {
pub fn get_provisioner(
&self,
db: impl LNVpsDb + 'static,
router: Option<impl Router + 'static>,
lnd: Client,
exchange: ExchangeRateCache,
) -> impl Provisioner + 'static {
match &self.provisioner {
ProvisionerConfig::Proxmox {
qemu,
ssh,
read_only,
} => LNVpsProvisioner::new(
*read_only,
qemu.clone(),
self.network_policy.clone(),
ssh.clone(),
db,
router,
lnd,
exchange,
),
}
db: Arc<dyn LNVpsDb>,
node: Arc<dyn LightningNode>,
exchange: Arc<dyn ExchangeRateService>,
) -> Arc<dyn Provisioner> {
Arc::new(LNVpsProvisioner::new(self.clone(), db, node, exchange))
}
pub fn get_router<'a>(&'a self) -> Result<Option<impl Router + 'static>> {
pub fn get_router(&self) -> Result<Option<Arc<dyn Router>>> {
match &self.router {
Some(RouterConfig::Mikrotik(api)) => match &api.credentials {
Credentials::UsernamePassword { username, password } => {
Ok(Some(MikrotikRouter::new(&api.url, username, password)))
}
Credentials::UsernamePassword { username, password } => Ok(Some(Arc::new(
MikrotikRouter::new(&api.url, username, password),
))),
_ => bail!("Only username/password is supported for Mikrotik routers"),
},
_ => Ok(None),

View File

@ -1,7 +1,7 @@
use crate::host::get_host_client;
use crate::host::proxmox::{VmInfo, VmStatus};
use crate::provisioner::Provisioner;
use crate::settings::{Settings, SmtpConfig};
use crate::settings::{ProvisionerConfig, Settings, SmtpConfig};
use crate::status::{VmRunningState, VmState, VmStateCache};
use anyhow::Result;
use chrono::{DateTime, Datelike, Days, Utc};
@ -14,6 +14,7 @@ use log::{debug, error, info};
use nostr::{EventBuilder, PublicKey, ToBech32};
use nostr_sdk::Client;
use std::ops::{Add, Sub};
use std::sync::Arc;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
#[derive(Debug)]
@ -34,18 +35,21 @@ pub enum WorkJob {
pub struct Worker {
settings: WorkerSettings,
db: Box<dyn LNVpsDb>,
provisioner: Box<dyn Provisioner>,
db: Arc<dyn LNVpsDb>,
provisioner: Arc<dyn Provisioner>,
nostr: Option<Client>,
vm_state_cache: VmStateCache,
tx: UnboundedSender<WorkJob>,
rx: UnboundedReceiver<WorkJob>,
client: Option<Client>,
last_check_vms: DateTime<Utc>,
}
pub struct WorkerSettings {
pub delete_after: u16,
pub smtp: Option<SmtpConfig>,
pub provisioner_config: ProvisionerConfig,
}
impl From<&Settings> for WorkerSettings {
@ -53,27 +57,28 @@ impl From<&Settings> for WorkerSettings {
WorkerSettings {
delete_after: val.delete_after,
smtp: val.smtp.clone(),
provisioner_config: val.provisioner.clone(),
}
}
}
impl Worker {
pub fn new<D: LNVpsDb + Clone + 'static, P: Provisioner + 'static>(
db: D,
provisioner: P,
pub fn new(
db: Arc<dyn LNVpsDb>,
provisioner: Arc<dyn Provisioner>,
settings: impl Into<WorkerSettings>,
vm_state_cache: VmStateCache,
client: Option<Client>,
nostr: Option<Client>,
) -> Self {
let (tx, rx) = unbounded_channel();
Self {
db: Box::new(db),
provisioner: Box::new(provisioner),
db,
provisioner,
vm_state_cache,
nostr,
settings: settings.into(),
tx,
rx,
client,
last_check_vms: Utc::now(),
}
}
@ -159,7 +164,7 @@ impl Worker {
debug!("Checking VM: {}", vm_id);
let vm = self.db.get_vm(vm_id).await?;
let host = self.db.get_host(vm.host_id).await?;
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.settings.provisioner_config)?;
match client.get_vm_status(&host.name, (vm.id + 100) as i32).await {
Ok(s) => self.handle_vm_info(s).await?,
@ -196,7 +201,7 @@ impl Worker {
pub async fn check_vms(&mut self) -> Result<()> {
let hosts = self.db.list_hosts().await?;
for host in hosts {
let client = get_host_client(&host)?;
let client = get_host_client(&host, &self.settings.provisioner_config)?;
for node in client.list_nodes().await? {
debug!("Checking vms for {}", node.name);
@ -233,7 +238,7 @@ impl Worker {
}
// delete vm if not paid (in new state)
if !vm.deleted && vm.expires < Utc::now().sub(Days::new(1)) && state.is_none() {
if vm.expires < Utc::now().sub(Days::new(1)) && state.is_none() {
info!("Deleting unpaid VM {}", vm.id);
self.provisioner.delete_vm(vm.id).await?;
}
@ -281,10 +286,10 @@ impl Worker {
}
}
if user.contact_nip4 {
// send dm
// TODO: send nip4 dm
}
if user.contact_nip17 {
if let Some(c) = self.client.as_ref() {
if let Some(c) = self.nostr.as_ref() {
let sig = c.signer().await?;
let ev = EventBuilder::private_msg(
&sig,