From 8ec143bd6b0523240854229cd4431b450c977cdc Mon Sep 17 00:00:00 2001 From: kieran Date: Tue, 4 Mar 2025 11:03:15 +0000 Subject: [PATCH] fix: impl configure_vm --- Cargo.toml | 2 +- lnvps_db/src/mysql.rs | 4 +-- src/api/mod.rs | 4 +-- src/api/routes.rs | 34 +++---------------- src/api/webhook.rs | 1 - src/dns/cloudflare.rs | 8 +++-- src/dns/mod.rs | 13 ++++++-- src/host/libvirt.rs | 4 +-- src/host/mod.rs | 45 +++++++++++++++++++++---- src/host/proxmox.rs | 40 ++++++++++++++-------- src/lightning/lnd.rs | 2 +- src/lightning/mod.rs | 5 ++- src/mocks.rs | 35 ++++++++++---------- src/provisioner/lnvps.rs | 71 +++++++++++----------------------------- src/router/mikrotik.rs | 48 +++++++++++++++++++++------ src/router/mod.rs | 16 +++------ src/settings.rs | 2 +- 17 files changed, 179 insertions(+), 155 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d62d69f..4ebc396 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" name = "api" [features] -default = ["mikrotik", "nostr-dm", "proxmox", "lnd", "bitvora", "cloudflare"] +default = ["mikrotik", "nostr-dm", "proxmox", "lnd", "cloudflare"] mikrotik = ["dep:reqwest"] nostr-dm = ["dep:nostr-sdk"] proxmox = ["dep:reqwest", "dep:ssh2", "dep:tokio-tungstenite"] diff --git a/lnvps_db/src/mysql.rs b/lnvps_db/src/mysql.rs index c7da376..2670f7e 100644 --- a/lnvps_db/src/mysql.rs +++ b/lnvps_db/src/mysql.rs @@ -68,7 +68,7 @@ impl LNVpsDb for LNVpsDbMysql { Ok(()) } - async fn delete_user(&self, id: u64) -> Result<()> { + async fn delete_user(&self, _id: u64) -> Result<()> { todo!() } @@ -93,7 +93,7 @@ impl LNVpsDb for LNVpsDbMysql { .map_err(Error::new) } - async fn delete_user_ssh_key(&self, id: u64) -> Result<()> { + async fn delete_user_ssh_key(&self, _id: u64) -> Result<()> { todo!() } diff --git a/src/api/mod.rs b/src/api/mod.rs index eb89078..f6e8b80 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,4 @@ -use rocket::{routes, Route}; +use rocket::Route; mod model; mod routes; @@ -10,4 +10,4 @@ pub fn routes() -> Vec { r } -pub use webhook::WEBHOOK_BRIDGE; \ No newline at end of file +pub use webhook::WEBHOOK_BRIDGE; diff --git a/src/api/routes.rs b/src/api/routes.rs index bfac231..894e25e 100644 --- a/src/api/routes.rs +++ b/src/api/routes.rs @@ -2,18 +2,17 @@ use crate::api::model::{ AccountPatchRequest, ApiUserSshKey, ApiVmIpAssignment, ApiVmOsImage, ApiVmPayment, ApiVmStatus, ApiVmTemplate, CreateSshKey, CreateVmRequest, VMPatchRequest, }; -use crate::host::get_host_client; +use crate::host::{get_host_client, FullVmInfo}; use crate::nip98::Nip98Auth; use crate::provisioner::LNVpsProvisioner; use crate::settings::Settings; use crate::status::{VmState, VmStateCache}; use crate::worker::WorkJob; -use anyhow::{bail, Result}; +use anyhow::Result; use futures::future::join_all; use lnvps_db::{IpRange, LNVpsDb}; -use log::{debug, error}; use nostr::util::hex; -use rocket::futures::{Sink, SinkExt, StreamExt}; +use rocket::futures::{SinkExt, StreamExt}; use rocket::serde::json::Json; use rocket::{get, patch, post, Responder, Route, State}; use rocket_okapi::gen::OpenApiGenerator; @@ -24,10 +23,8 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use ssh_key::PublicKey; use std::collections::{HashMap, HashSet}; -use std::fmt::Display; use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; -use ws::Message; pub fn routes() -> Vec { openapi_get_routes![ @@ -229,9 +226,10 @@ async fn v1_patch_vm( db.update_vm(&vm).await?; + let info = FullVmInfo::load(vm.id, (*db).clone()).await?; let host = db.get_host(vm.host_id).await?; let client = get_host_client(&host, &settings.provisioner)?; - client.configure_vm(&vm).await?; + client.configure_vm(&info).await?; ApiData::ok(()) } @@ -480,25 +478,3 @@ async fn v1_get_payment( ApiData::ok(payment.into()) } - -#[get("/api/v1/console/?")] -async fn v1_terminal_proxy( - auth: &str, - db: &State>, - _provisioner: &State>, - id: u64, - _ws: ws::WebSocket, -) -> Result, &'static str> { - let auth = Nip98Auth::from_base64(auth).map_err(|_| "Missing or invalid auth param")?; - if auth.check(&format!("/api/v1/console/{id}"), "GET").is_err() { - return Err("Invalid auth event"); - } - let pubkey = auth.event.pubkey.to_bytes(); - let uid = db.upsert_user(&pubkey).await.map_err(|_| "Insert failed")?; - let vm = db.get_vm(id).await.map_err(|_| "VM not found")?; - if uid != vm.user_id { - return Err("VM does not belong to you"); - } - - Err("Not implemented") -} diff --git a/src/api/webhook.rs b/src/api/webhook.rs index ebe2fed..4eccd34 100644 --- a/src/api/webhook.rs +++ b/src/api/webhook.rs @@ -1,4 +1,3 @@ -use anyhow::anyhow; use lettre::message::header::Headers; use log::warn; use reqwest::header::HeaderMap; diff --git a/src/dns/cloudflare.rs b/src/dns/cloudflare.rs index 2eebc09..330bca8 100644 --- a/src/dns/cloudflare.rs +++ b/src/dns/cloudflare.rs @@ -1,4 +1,4 @@ -use crate::dns::{BasicRecord, DnsServer}; +use crate::dns::{BasicRecord, DnsServer, RecordType}; use crate::json_api::JsonApi; use lnvps_db::async_trait; use serde::{Deserialize, Serialize}; @@ -39,7 +39,8 @@ impl DnsServer for Cloudflare { Ok(BasicRecord { name: id_response.result.name, value: value.to_string(), - id: id_response.result.id.unwrap(), + id: id_response.result.id, + kind: RecordType::PTR, }) } @@ -67,7 +68,8 @@ impl DnsServer for Cloudflare { Ok(BasicRecord { name: id_response.result.name, value: ip.to_string(), - id: id_response.result.id.unwrap(), + id: id_response.result.id, + kind: RecordType::A, }) } diff --git a/src/dns/mod.rs b/src/dns/mod.rs index 59b8252..952b418 100644 --- a/src/dns/mod.rs +++ b/src/dns/mod.rs @@ -1,5 +1,6 @@ use anyhow::Result; use lnvps_db::async_trait; +use serde::{Deserialize, Serialize}; use std::net::IpAddr; #[cfg(feature = "cloudflare")] @@ -22,9 +23,17 @@ pub trait DnsServer: Send + Sync { async fn delete_a_record(&self, name: &str) -> Result<()>; } +#[derive(Clone, Debug)] +pub enum RecordType { + A, + AAAA, + PTR, +} + #[derive(Debug, Clone)] pub struct BasicRecord { pub name: String, pub value: String, - pub id: String, -} \ No newline at end of file + pub id: Option, + pub kind: RecordType, +} diff --git a/src/host/libvirt.rs b/src/host/libvirt.rs index 0c7f50e..d85229a 100644 --- a/src/host/libvirt.rs +++ b/src/host/libvirt.rs @@ -1,4 +1,4 @@ -use crate::host::{CreateVmRequest, VmHostClient}; +use crate::host::{FullVmInfo, VmHostClient}; use crate::status::VmState; use lnvps_db::{async_trait, Vm, VmOsImage}; @@ -26,7 +26,7 @@ impl VmHostClient for LibVirt { todo!() } - async fn create_vm(&self, cfg: &CreateVmRequest) -> anyhow::Result<()> { + async fn create_vm(&self, cfg: &FullVmInfo) -> anyhow::Result<()> { todo!() } diff --git a/src/host/mod.rs b/src/host/mod.rs index 2a021a5..8e0e8e4 100644 --- a/src/host/mod.rs +++ b/src/host/mod.rs @@ -1,10 +1,12 @@ use crate::settings::ProvisionerConfig; use crate::status::VmState; use anyhow::{bail, Result}; +use futures::future::join_all; use lnvps_db::{ - async_trait, IpRange, UserSshKey, Vm, VmHost, VmHostDisk, VmHostKind, VmIpAssignment, + async_trait, IpRange, LNVpsDb, UserSshKey, Vm, VmHost, VmHostDisk, VmHostKind, VmIpAssignment, VmOsImage, VmTemplate, }; +use std::collections::HashSet; use std::sync::Arc; #[cfg(feature = "libvirt")] @@ -31,13 +33,13 @@ pub trait VmHostClient: Send + Sync { async fn reset_vm(&self, vm: &Vm) -> Result<()>; /// Spawn a VM - async fn create_vm(&self, cfg: &CreateVmRequest) -> Result<()>; + async fn create_vm(&self, cfg: &FullVmInfo) -> Result<()>; /// Get the running status of a VM async fn get_vm_state(&self, vm: &Vm) -> Result; - /// Apply vm configuration (update) - async fn configure_vm(&self, vm: &Vm) -> Result<()>; + /// Apply vm configuration (patch) + async fn configure_vm(&self, cfg: &FullVmInfo) -> Result<()>; } pub fn get_host_client(host: &VmHost, cfg: &ProvisionerConfig) -> Result> { @@ -69,9 +71,8 @@ pub fn get_host_client(host: &VmHost, cfg: &ProvisionerConfig) -> Result) -> Result { + let vm = db.get_vm(vm_id).await?; + let template = db.get_vm_template(vm.template_id).await?; + let image = db.get_os_image(vm.image_id).await?; + let disk = db.get_host_disk(vm.disk_id).await?; + let ssh_key = db.get_user_ssh_key(vm.ssh_key_id).await?; + let ips = db.list_vm_ip_assignments(vm_id).await?; + + let ip_range_ids: HashSet = ips.iter().map(|i| i.ip_range_id).collect(); + let ip_ranges: Vec<_> = ip_range_ids.iter().map(|i| db.get_ip_range(*i)).collect(); + let ranges: Vec = join_all(ip_ranges) + .await + .into_iter() + .filter_map(Result::ok) + .collect(); + + // create VM + Ok(FullVmInfo { + vm, + template, + image, + ips, + disk, + ranges, + ssh_key, + }) + } +} diff --git a/src/host/proxmox.rs b/src/host/proxmox.rs index 956bd6f..4571c88 100644 --- a/src/host/proxmox.rs +++ b/src/host/proxmox.rs @@ -1,4 +1,4 @@ -use crate::host::{CreateVmRequest, VmHostClient}; +use crate::host::{FullVmInfo, VmHostClient}; use crate::json_api::JsonApi; use crate::settings::{QemuConfig, SshConfig}; use crate::ssh_client::SshClient; @@ -8,23 +8,17 @@ use chrono::Utc; use futures::future::join_all; use ipnetwork::IpNetwork; use lnvps_db::{async_trait, DiskType, IpRange, LNVpsDb, Vm, VmIpAssignment, VmOsImage}; -use log::{debug, info}; +use log::{info, warn}; use rand::random; use reqwest::header::{HeaderMap, AUTHORIZATION}; use reqwest::{ClientBuilder, Method, Url}; -use serde::de::value::I32Deserializer; -use serde::de::DeserializeOwned; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fmt::{Debug, Display, Formatter}; use std::net::IpAddr; use std::str::FromStr; -use std::sync::Arc; use std::time::Duration; -use tokio::net::TcpStream; use tokio::time::sleep; -use tokio_tungstenite::tungstenite::handshake::client::{generate_key, Request}; -use tokio_tungstenite::{Connector, MaybeTlsStream, WebSocketStream}; pub struct ProxmoxClient { api: JsonApi, @@ -359,7 +353,7 @@ impl ProxmoxClient { } impl ProxmoxClient { - fn make_config(&self, value: &CreateVmRequest) -> Result { + fn make_config(&self, value: &FullVmInfo) -> Result { let mut ip_config = value .ips .iter() @@ -474,7 +468,7 @@ impl VmHostClient for ProxmoxClient { Ok(()) } - async fn create_vm(&self, req: &CreateVmRequest) -> Result<()> { + async fn create_vm(&self, req: &FullVmInfo) -> Result<()> { let config = self.make_config(&req)?; let vm_id = req.vm.id.into(); let t_create = self @@ -511,11 +505,14 @@ impl VmHostClient for ProxmoxClient { size: req.template.disk_size.to_string(), }) .await?; + // TODO: rollback self.wait_for_task(&j_resize).await?; // try start, otherwise ignore error (maybe its already running) if let Ok(j_start) = self.start_vm(&self.node, vm_id).await { - self.wait_for_task(&j_start).await?; + if let Err(e) = self.wait_for_task(&j_start).await { + warn!("Failed to start vm: {}", e); + } } Ok(()) @@ -539,8 +536,23 @@ impl VmHostClient for ProxmoxClient { }) } - async fn configure_vm(&self, vm: &Vm) -> Result<()> { - todo!() + async fn configure_vm(&self, cfg: &FullVmInfo) -> Result<()> { + let mut config = self.make_config(&cfg)?; + + // dont re-create the disks + config.scsi_0 = None; + config.scsi_1 = None; + config.efi_disk_0 = None; + + self.configure_vm(ConfigureVm { + node: self.node.clone(), + vm_id: cfg.vm.id.into(), + current: None, + snapshot: None, + config, + }) + .await?; + Ok(()) } } diff --git a/src/lightning/lnd.rs b/src/lightning/lnd.rs index 2afcce2..366b54e 100644 --- a/src/lightning/lnd.rs +++ b/src/lightning/lnd.rs @@ -1,4 +1,3 @@ -use std::path::Path; use crate::lightning::{AddInvoiceRequest, AddInvoiceResult, InvoiceUpdate, LightningNode}; use anyhow::Result; use fedimint_tonic_lnd::invoicesrpc::lookup_invoice_msg::InvoiceRef; @@ -9,6 +8,7 @@ use fedimint_tonic_lnd::{connect, Client}; use futures::StreamExt; use lnvps_db::async_trait; use nostr_sdk::async_utility::futures_util::Stream; +use std::path::Path; use std::pin::Pin; pub struct LndNode { diff --git a/src/lightning/mod.rs b/src/lightning/mod.rs index beeb0cd..b56b69d 100644 --- a/src/lightning/mod.rs +++ b/src/lightning/mod.rs @@ -53,7 +53,10 @@ pub async fn get_node(settings: &Settings) -> Result> { macaroon, } => Ok(Arc::new(lnd::LndNode::new(url, cert, macaroon).await?)), #[cfg(feature = "bitvora")] - LightningConfig::Bitvora { token, webhook_secret } => Ok(Arc::new(bitvora::BitvoraNode::new(token, webhook_secret))), + LightningConfig::Bitvora { + token, + webhook_secret, + } => Ok(Arc::new(bitvora::BitvoraNode::new(token, webhook_secret))), _ => anyhow::bail!("Unsupported lightning config!"), } } diff --git a/src/mocks.rs b/src/mocks.rs index 97920ba..e4b6051 100644 --- a/src/mocks.rs +++ b/src/mocks.rs @@ -1,5 +1,6 @@ -use crate::dns::{BasicRecord, DnsServer}; -use crate::host::{CreateVmRequest, VmHostClient}; +#![allow(unused)] +use crate::dns::{BasicRecord, DnsServer, RecordType}; +use crate::host::{FullVmInfo, VmHostClient}; use crate::lightning::{AddInvoiceRequest, AddInvoiceResult, InvoiceUpdate, LightningNode}; use crate::router::{ArpEntry, Router}; use crate::settings::NetworkPolicy; @@ -515,23 +516,21 @@ impl Router for MockRouter { mac: &str, interface: &str, comment: Option<&str>, - ) -> anyhow::Result<()> { + ) -> anyhow::Result { let mut arp = self.arp.lock().await; if arp.iter().any(|(k, v)| v.address == ip.to_string()) { bail!("Address is already in use"); } let max_id = *arp.keys().max().unwrap_or(&0); - arp.insert( - max_id + 1, - ArpEntry { - id: Some((max_id + 1).to_string()), - address: ip.to_string(), - mac_address: Some(mac.to_string()), - interface: interface.to_string(), - comment: comment.map(|s| s.to_string()), - }, - ); - Ok(()) + let e = ArpEntry { + id: (max_id + 1).to_string(), + address: ip.to_string(), + mac_address: mac.to_string(), + interface: Some(interface.to_string()), + comment: comment.map(|s| s.to_string()), + }; + arp.insert(max_id + 1, e.clone()); + Ok(e) } async fn remove_arp_entry(&self, id: &str) -> anyhow::Result<()> { @@ -636,7 +635,7 @@ impl VmHostClient for MockVmHost { Ok(()) } - async fn create_vm(&self, cfg: &CreateVmRequest) -> anyhow::Result<()> { + async fn create_vm(&self, cfg: &FullVmInfo) -> anyhow::Result<()> { let mut vms = self.vms.lock().await; let max_id = *vms.keys().max().unwrap_or(&0); vms.insert( @@ -717,7 +716,8 @@ impl DnsServer for MockDnsServer { Ok(BasicRecord { name: format!("{}.X.Y.Z.in-addr.arpa", key), value: value.to_string(), - id, + id: Some(id), + kind: RecordType::PTR, }) } @@ -746,7 +746,8 @@ impl DnsServer for MockDnsServer { Ok(BasicRecord { name: fqdn, value: ip.to_string(), - id, + id: Some(id), + kind: RecordType::A, }) } diff --git a/src/provisioner/lnvps.rs b/src/provisioner/lnvps.rs index 96683f1..5957d7a 100644 --- a/src/provisioner/lnvps.rs +++ b/src/provisioner/lnvps.rs @@ -1,6 +1,6 @@ use crate::dns::DnsServer; use crate::exchange::{ExchangeRateService, Ticker}; -use crate::host::{get_host_client, CreateVmRequest, VmHostClient}; +use crate::host::{get_host_client, FullVmInfo}; use crate::lightning::{AddInvoiceRequest, LightningNode}; use crate::provisioner::{NetworkProvisioner, ProvisionerMethod}; use crate::router::Router; @@ -8,19 +8,16 @@ use crate::settings::{NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Set use anyhow::{bail, Result}; use chrono::{Days, Months, Utc}; use futures::future::join_all; -use lnvps_db::{DiskType, IpRange, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment}; -use log::{debug, info, warn}; +use lnvps_db::{IpRange, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment}; +use log::{info, warn}; use nostr::util::hex; use rand::random; -use rocket::futures::{SinkExt, StreamExt}; -use std::collections::{HashMap, HashSet}; -use std::fmt::format; +use std::collections::HashSet; use std::net::IpAddr; use std::ops::Add; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use tokio::net::TcpStream; /// Main provisioner class for LNVPS /// @@ -62,14 +59,11 @@ impl LNVpsProvisioner { if let NetworkAccessPolicy::StaticArp { .. } = &self.network_policy.access { if let Some(r) = self.router.as_ref() { let ent = r.list_arp_entry().await?; - if let Some(ent) = ent.iter().find(|e| { - e.mac_address - .as_ref() - .map(|m| m.eq_ignore_ascii_case(&vm.mac_address)) - .unwrap_or(false) - }) { - r.remove_arp_entry(ent.id.as_ref().unwrap().as_str()) - .await?; + if let Some(ent) = ent + .iter() + .find(|e| e.mac_address.eq_ignore_ascii_case(&vm.mac_address)) + { + r.remove_arp_entry(&ent.id).await?; } else { warn!("ARP entry not found, skipping") } @@ -101,14 +95,14 @@ impl LNVpsProvisioner { let sub_name = format!("vm-{}", vm.id); let fwd = dns.add_a_record(&sub_name, ip.clone()).await?; assignment.dns_forward = Some(fwd.name.clone()); - assignment.dns_forward_ref = Some(fwd.id); + assignment.dns_forward_ref = fwd.id; match ip { IpAddr::V4(ip) => { let last_octet = ip.octets()[3].to_string(); let rev = dns.add_ptr_record(&last_octet, &fwd.name).await?; assignment.dns_reverse = Some(fwd.name.clone()); - assignment.dns_reverse_ref = Some(rev.id); + assignment.dns_reverse_ref = rev.id; } IpAddr::V6(_) => { warn!("IPv6 forward DNS not supported yet") @@ -290,40 +284,16 @@ impl LNVpsProvisioner { if self.read_only { bail!("Cant spawn VM's in read-only mode") } - let vm = self.db.get_vm(vm_id).await?; - let template = self.db.get_vm_template(vm.template_id).await?; - let host = self.db.get_host(vm.host_id).await?; - let image = self.db.get_os_image(vm.image_id).await?; - let disk = self.db.get_host_disk(vm.disk_id).await?; - let ssh_key = self.db.get_user_ssh_key(vm.ssh_key_id).await?; - - let client = get_host_client(&host, &self.provisioner_config)?; - // setup network by allocating some IP space - let ips = self.allocate_ips(vm.id).await?; + self.allocate_ips(vm_id).await?; - let ip_range_ids: HashSet = ips.iter().map(|i| i.ip_range_id).collect(); - let ip_ranges: Vec<_> = ip_range_ids - .iter() - .map(|i| self.db.get_ip_range(*i)) - .collect(); - let ranges: Vec = join_all(ip_ranges) - .await - .into_iter() - .filter_map(Result::ok) - .collect(); + // load full info + let info = FullVmInfo::load(vm_id, self.db.clone()).await?; - // create VM - let req = CreateVmRequest { - vm, - template, - image, - ips, - disk, - ranges, - ssh_key, - }; - client.create_vm(&req).await?; + // load host client + let host = self.db.get_host(info.vm.host_id).await?; + let client = get_host_client(&host, &self.provisioner_config)?; + client.create_vm(&info).await?; Ok(()) } @@ -410,7 +380,6 @@ mod tests { let node = Arc::new(MockNode::default()); let rates = Arc::new(DefaultRateCache::default()); let router = settings.get_router().expect("router").unwrap(); - let dns = settings.get_dns().expect("dns").unwrap(); let provisioner = LNVpsProvisioner::new(settings, db.clone(), node.clone(), rates.clone()); let pubkey: [u8; 32] = random(); @@ -433,8 +402,8 @@ mod tests { let arp = router.list_arp_entry().await?; assert_eq!(1, arp.len()); let arp = arp.first().unwrap(); - assert_eq!(&vm.mac_address, arp.mac_address.as_ref().unwrap()); - assert_eq!(ROUTER_BRIDGE, &arp.interface); + assert_eq!(&vm.mac_address, &arp.mac_address); + assert_eq!(ROUTER_BRIDGE, arp.interface.as_ref().unwrap()); println!("{:?}", arp); let ips = db.list_vm_ip_assignments(vm.id).await?; diff --git a/src/router/mikrotik.rs b/src/router/mikrotik.rs index 56f5d67..3315946 100644 --- a/src/router/mikrotik.rs +++ b/src/router/mikrotik.rs @@ -3,8 +3,10 @@ use crate::router::{ArpEntry, Router}; use anyhow::Result; use base64::engine::general_purpose::STANDARD; use base64::Engine; +use log::debug; use reqwest::Method; use rocket::async_trait; +use serde::{Deserialize, Serialize}; use std::net::IpAddr; pub struct MikrotikRouter { @@ -26,8 +28,8 @@ impl MikrotikRouter { #[async_trait] impl Router for MikrotikRouter { async fn list_arp_entry(&self) -> Result> { - let rsp: Vec = self.api.req(Method::GET, "/rest/ip/arp", ()).await?; - Ok(rsp) + let rsp: Vec = self.api.req(Method::GET, "/rest/ip/arp", ()).await?; + Ok(rsp.into_iter().map(|e| e.into()).collect()) } async fn add_arp_entry( @@ -36,31 +38,57 @@ impl Router for MikrotikRouter { mac: &str, arp_interface: &str, comment: Option<&str>, - ) -> Result<()> { - let _rsp: ArpEntry = self + ) -> Result { + let rsp: MikrotikArpEntry = self .api .req( Method::PUT, "/rest/ip/arp", - ArpEntry { + MikrotikArpEntry { + id: None, address: ip.to_string(), mac_address: Some(mac.to_string()), interface: arp_interface.to_string(), comment: comment.map(|c| c.to_string()), - ..Default::default() }, ) .await?; - - Ok(()) + debug!("{:?}", rsp); + Ok(rsp.into()) } async fn remove_arp_entry(&self, id: &str) -> Result<()> { - let _rsp: ArpEntry = self + let rsp: MikrotikArpEntry = self .api .req(Method::DELETE, &format!("/rest/ip/arp/{id}"), ()) .await?; - + debug!("{:?}", rsp); Ok(()) } } + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MikrotikArpEntry { + #[serde(rename = ".id")] + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + pub address: String, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "mac-address")] + pub mac_address: Option, + pub interface: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub comment: Option, +} + +impl Into for MikrotikArpEntry { + fn into(self) -> ArpEntry { + ArpEntry { + id: self.id.unwrap(), + address: self.address, + mac_address: self.mac_address.unwrap(), + interface: Some(self.interface), + comment: self.comment, + } + } +} diff --git a/src/router/mod.rs b/src/router/mod.rs index 1c957c7..c685969 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -1,6 +1,5 @@ use anyhow::Result; use rocket::async_trait; -use rocket::serde::{Deserialize, Serialize}; use std::net::IpAddr; /// Router defines a network device used to access the hosts @@ -19,21 +18,16 @@ pub trait Router: Send + Sync { mac: &str, interface: &str, comment: Option<&str>, - ) -> Result<()>; + ) -> Result; async fn remove_arp_entry(&self, id: &str) -> Result<()>; } -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone)] pub struct ArpEntry { - #[serde(rename = ".id")] - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, + pub id: String, pub address: String, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(rename = "mac-address")] - pub mac_address: Option, - pub interface: String, - #[serde(skip_serializing_if = "Option::is_none")] + pub mac_address: String, + pub interface: Option, pub comment: Option, } diff --git a/src/settings.rs b/src/settings.rs index fec31f2..04540e2 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -181,7 +181,7 @@ impl Settings { pub fn get_router(&self) -> Result>> { #[cfg(test)] { - if let Some(router) = &self.router { + if let Some(_router) = &self.router { let router = crate::mocks::MockRouter::new(self.network_policy.clone()); Ok(Some(Arc::new(router))) } else {