From 45dd0c43986aa9d9158c945795d2d6632aeeac57 Mon Sep 17 00:00:00 2001 From: kieran Date: Tue, 11 Mar 2025 12:42:25 +0000 Subject: [PATCH] feat: fiat payments (revolut) ref: #24 --- Cargo.lock | 2 + Cargo.toml | 9 +- config.yaml | 1 + .../20250310153305_fiat_payment.sql | 6 + lnvps_db/src/lib.rs | 3 + lnvps_db/src/model.rs | 49 +++++- lnvps_db/src/mysql.rs | 20 ++- src/api/mod.rs | 1 + src/api/model.rs | 82 ++++++++-- src/api/routes.rs | 64 ++++++-- src/api/webhook.rs | 24 ++- src/bin/api.rs | 20 +-- src/data_migration/dns.rs | 6 +- src/dns/cloudflare.rs | 8 +- src/exchange.rs | 12 +- src/fiat/mod.rs | 25 +++ src/fiat/revolut.rs | 154 ++++++++++++++++++ src/host/libvirt.rs | 2 +- src/json_api.rs | 32 +++- src/lib.rs | 6 +- src/lightning/bitvora.rs | 2 +- src/lightning/lnd.rs | 1 - src/lightning/mod.rs | 1 - src/mocks.rs | 12 +- src/nip98.rs | 2 +- src/{ => payments}/invoice.rs | 16 +- src/payments/mod.rs | 56 +++++++ src/payments/revolut.rs | 129 +++++++++++++++ src/provisioner/lnvps.rs | 113 +++++++++---- src/provisioner/mod.rs | 2 +- src/provisioner/pricing.rs | 86 ++++++---- src/settings.rs | 27 +++ 32 files changed, 822 insertions(+), 151 deletions(-) create mode 100644 lnvps_db/migrations/20250310153305_fiat_payment.sql create mode 100644 src/fiat/mod.rs create mode 100644 src/fiat/revolut.rs rename src/{ => payments}/invoice.rs (76%) create mode 100644 src/payments/mod.rs create mode 100644 src/payments/revolut.rs diff --git a/Cargo.lock b/Cargo.lock index cdfab75..40761f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2009,6 +2009,7 @@ dependencies = [ "fern", "futures", "hex", + "hmac", "ipnetwork", "lettre", "lnvps_db", @@ -2024,6 +2025,7 @@ dependencies = [ "schemars", "serde", "serde_json", + "sha2", "ssh-key", "ssh2", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 4afeee0..914d266 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" name = "api" [features] -default = ["mikrotik", "nostr-dm", "proxmox", "lnd", "cloudflare"] +default = ["mikrotik", "nostr-dm", "proxmox", "lnd", "cloudflare", "revolut"] mikrotik = ["dep:reqwest"] nostr-dm = ["dep:nostr-sdk"] proxmox = ["dep:reqwest", "dep:ssh2", "dep:tokio-tungstenite"] @@ -15,6 +15,7 @@ libvirt = ["dep:virt"] lnd = ["dep:fedimint-tonic-lnd"] bitvora = ["dep:reqwest", "dep:tokio-stream"] cloudflare = ["dep:reqwest"] +revolut = ["dep:reqwest", "dep:sha2", "dep:hmac"] [dependencies] lnvps_db = { path = "lnvps_db" } @@ -57,4 +58,8 @@ virt = { version = "0.4.2", optional = true } fedimint-tonic-lnd = { version = "0.2.0", default-features = false, features = ["invoicesrpc"], optional = true } #bitvora -tokio-stream = { version = "0.1.17", features = ["sync"], optional = true } \ No newline at end of file +tokio-stream = { version = "0.1.17", features = ["sync"], optional = true } + +#revolut +sha2 = { version = "0.10.8", optional = true } +hmac = { version = "0.12.1", optional = true } \ No newline at end of file diff --git a/config.yaml b/config.yaml index 4ea6fd9..d1a2398 100644 --- a/config.yaml +++ b/config.yaml @@ -5,6 +5,7 @@ lightning: cert: "/home/kieran/.polar/networks/2/volumes/lnd/alice/tls.cert" macaroon: "/home/kieran/.polar/networks/2/volumes/lnd/alice/data/chain/bitcoin/regtest/admin.macaroon" delete-after: 3 +public-url: "https://api.lnvps.net" provisioner: proxmox: read-only: false diff --git a/lnvps_db/migrations/20250310153305_fiat_payment.sql b/lnvps_db/migrations/20250310153305_fiat_payment.sql new file mode 100644 index 0000000..536ef2f --- /dev/null +++ b/lnvps_db/migrations/20250310153305_fiat_payment.sql @@ -0,0 +1,6 @@ +alter table vm_payment + add column currency varchar(5) not null default 'BTC', + add column payment_method smallint unsigned not null default 0, + add column external_id varchar(255), + change invoice external_data varchar (4096) NOT NULL, + drop column settle_index; diff --git a/lnvps_db/src/lib.rs b/lnvps_db/src/lib.rs index b024111..6f951e3 100644 --- a/lnvps_db/src/lib.rs +++ b/lnvps_db/src/lib.rs @@ -131,6 +131,9 @@ pub trait LNVpsDb: Sync + Send { /// Get VM payment by payment id async fn get_vm_payment(&self, id: &Vec) -> Result; + /// Get VM payment by payment id + async fn get_vm_payment_by_ext_id(&self, id: &str) -> Result; + /// Update a VM payment record async fn update_vm_payment(&self, vm_payment: &VmPayment) -> Result<()>; diff --git a/lnvps_db/src/model.rs b/lnvps_db/src/model.rs index 80d3640..7ca44ad 100644 --- a/lnvps_db/src/model.rs +++ b/lnvps_db/src/model.rs @@ -1,8 +1,9 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Result}; use chrono::{DateTime, Utc}; -use sqlx::FromRow; +use sqlx::{FromRow, Type}; use std::fmt::{Display, Formatter}; use std::path::PathBuf; +use std::str::FromStr; use url::Url; #[derive(FromRow, Clone, Debug)] @@ -309,17 +310,53 @@ impl Display for VmIpAssignment { #[derive(FromRow, Clone, Debug, Default)] pub struct VmPayment { - /// Payment hash pub id: Vec, pub vm_id: u64, pub created: DateTime, pub expires: DateTime, pub amount: u64, - pub invoice: String, + pub currency: String, + pub payment_method: PaymentMethod, + /// External data (invoice / json) + pub external_data: String, + /// External id on other system + pub external_id: Option, pub is_paid: bool, - /// Exchange rate + /// TODO: handle other base currencies + /// Exchange rate back to base currency (EUR) pub rate: f32, /// Number of seconds this payment will add to vm expiry pub time_value: u64, - pub settle_index: Option, +} + +#[derive(Type, Clone, Copy, Debug, Default, PartialEq)] +#[repr(u16)] +pub enum PaymentMethod { + #[default] + Lightning, + Revolut, + Paypal, +} + +impl Display for PaymentMethod { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + PaymentMethod::Lightning => write!(f, "Lightning"), + PaymentMethod::Revolut => write!(f, "Revolut"), + PaymentMethod::Paypal => write!(f, "PayPal"), + } + } +} + +impl FromStr for PaymentMethod { + type Err = anyhow::Error; + + fn from_str(s: &str) -> std::result::Result { + match s { + "lightning" => Ok(PaymentMethod::Lightning), + "revolut" => Ok(PaymentMethod::Revolut), + "paypal" => Ok(PaymentMethod::Paypal), + _ => bail!("Unknown payment method: {}", s), + } + } } diff --git a/lnvps_db/src/mysql.rs b/lnvps_db/src/mysql.rs index 376125b..975fcee 100644 --- a/lnvps_db/src/mysql.rs +++ b/lnvps_db/src/mysql.rs @@ -387,16 +387,19 @@ impl LNVpsDb for LNVpsDbMysql { } async fn insert_vm_payment(&self, vm_payment: &VmPayment) -> Result<()> { - sqlx::query("insert into vm_payment(id,vm_id,created,expires,amount,invoice,time_value,is_paid,rate) values(?,?,?,?,?,?,?,?,?)") + sqlx::query("insert into vm_payment(id,vm_id,created,expires,amount,currency,payment_method,time_value,is_paid,rate,external_id,external_data) values(?,?,?,?,?,?,?,?,?,?,?,?)") .bind(&vm_payment.id) .bind(vm_payment.vm_id) .bind(vm_payment.created) .bind(vm_payment.expires) .bind(vm_payment.amount) - .bind(&vm_payment.invoice) + .bind(&vm_payment.currency) + .bind(&vm_payment.payment_method) .bind(vm_payment.time_value) .bind(vm_payment.is_paid) .bind(vm_payment.rate) + .bind(&vm_payment.external_id) + .bind(&vm_payment.external_data) .execute(&self.db) .await .map_err(Error::new)?; @@ -411,6 +414,14 @@ impl LNVpsDb for LNVpsDbMysql { .map_err(Error::new) } + async fn get_vm_payment_by_ext_id(&self, id: &str) -> Result { + sqlx::query_as("select * from vm_payment where external_id=?") + .bind(id) + .fetch_one(&self.db) + .await + .map_err(Error::new) + } + async fn update_vm_payment(&self, vm_payment: &VmPayment) -> Result<()> { sqlx::query("update vm_payment set is_paid = ? where id = ?") .bind(vm_payment.is_paid) @@ -428,8 +439,7 @@ impl LNVpsDb for LNVpsDbMysql { let mut tx = self.db.begin().await?; - sqlx::query("update vm_payment set is_paid = true, settle_index = ? where id = ?") - .bind(vm_payment.settle_index) + sqlx::query("update vm_payment set is_paid = true where id = ?") .bind(&vm_payment.id) .execute(&mut *tx) .await?; @@ -446,7 +456,7 @@ impl LNVpsDb for LNVpsDbMysql { async fn last_paid_invoice(&self) -> Result> { sqlx::query_as( - "select * from vm_payment where is_paid = true order by settle_index desc limit 1", + "select * from vm_payment where is_paid = true order by created desc limit 1", ) .fetch_optional(&self.db) .await diff --git a/src/api/mod.rs b/src/api/mod.rs index f6e8b80..a76d55b 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -10,4 +10,5 @@ pub fn routes() -> Vec { r } +pub use webhook::WebhookMessage; pub use webhook::WEBHOOK_BRIDGE; diff --git a/src/api/model.rs b/src/api/model.rs index d7b6fa7..30ed4c3 100644 --- a/src/api/model.rs +++ b/src/api/model.rs @@ -1,16 +1,16 @@ use crate::exchange::{alt_prices, Currency, CurrencyAmount, ExchangeRateService}; -use crate::provisioner::{PricingData, PricingEngine}; +use crate::provisioner::PricingEngine; use crate::status::VmState; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{anyhow, bail, Result}; use chrono::{DateTime, Utc}; use ipnetwork::IpNetwork; use lnvps_db::{ - LNVpsDb, Vm, VmCostPlan, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate, VmHost, - VmHostRegion, VmTemplate, + LNVpsDb, PaymentMethod, Vm, VmCostPlan, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate, VmHostRegion, VmTemplate, }; use nostr::util::hex; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -95,9 +95,9 @@ impl From for DiskType { } } -impl Into for DiskType { - fn into(self) -> lnvps_db::DiskType { - match self { +impl From for lnvps_db::DiskType { + fn from(val: DiskType) -> Self { + match val { DiskType::HDD => lnvps_db::DiskType::HDD, DiskType::SSD => lnvps_db::DiskType::SSD, } @@ -143,7 +143,7 @@ impl ApiTemplatesResponse { pub async fn expand_pricing(&mut self, rates: &Arc) -> Result<()> { let rates = rates.list_rates().await?; - for mut template in &mut self.templates { + for template in &mut self.templates { let list_price = CurrencyAmount(template.cost_plan.currency, template.cost_plan.amount); for alt_price in alt_prices(&rates, list_price) { template.cost_plan.other_price.push(ApiPrice { @@ -335,8 +335,8 @@ impl ApiVmTemplate { cpu: template.cpu, memory: template.memory, disk_size: template.disk_size, - disk_type: template.disk_type.clone().into(), - disk_interface: template.disk_interface.clone().into(), + disk_type: template.disk_type.into(), + disk_interface: template.disk_interface.into(), cost_plan: ApiVmCostPlan { id: cost_plan.id, name: cost_plan.name.clone(), @@ -468,14 +468,14 @@ impl From for ApiVmOsImage { #[derive(Serialize, Deserialize, JsonSchema)] pub struct ApiVmPayment { - /// Payment hash hex pub id: String, pub vm_id: u64, pub created: DateTime, pub expires: DateTime, pub amount: u64, - pub invoice: String, + pub currency: String, pub is_paid: bool, + pub data: ApiPaymentData, } impl From for ApiVmPayment { @@ -486,8 +486,64 @@ impl From for ApiVmPayment { created: value.created, expires: value.expires, amount: value.amount, - invoice: value.invoice, + currency: value.currency, is_paid: value.is_paid, + data: match &value.payment_method { + PaymentMethod::Lightning => ApiPaymentData::Lightning(value.external_data), + PaymentMethod::Revolut => { + #[derive(Deserialize)] + struct RevolutData { + pub token: String, + } + let data: RevolutData = serde_json::from_str(&value.external_data).unwrap(); + ApiPaymentData::Revolut { token: data.token } + } + PaymentMethod::Paypal => { + todo!() + } + }, + } + } +} + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct ApiPaymentInfo { + pub name: ApiPaymentMethod, + + #[serde(skip_serializing_if = "HashMap::is_empty")] + pub metadata: HashMap, + + pub currencies: Vec, +} + +/// Payment data related to the payment method +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "lowercase")] +pub enum ApiPaymentData { + /// Just an LN invoice + Lightning(String), + /// Revolut order data + Revolut { + /// Order token + token: String, + }, +} + +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "lowercase")] +pub enum ApiPaymentMethod { + #[default] + Lightning, + Revolut, + Paypal, +} + +impl From for ApiPaymentMethod { + fn from(value: PaymentMethod) -> Self { + match value { + PaymentMethod::Lightning => ApiPaymentMethod::Lightning, + PaymentMethod::Revolut => ApiPaymentMethod::Revolut, + PaymentMethod::Paypal => ApiPaymentMethod::Paypal, } } } diff --git a/src/api/routes.rs b/src/api/routes.rs index 4d4c4a2..d5bddb8 100644 --- a/src/api/routes.rs +++ b/src/api/routes.rs @@ -1,19 +1,21 @@ use crate::api::model::{ - AccountPatchRequest, ApiCustomTemplateDiskParam, ApiCustomTemplateParams, ApiCustomVmOrder, - ApiCustomVmRequest, ApiPrice, ApiTemplatesResponse, ApiUserSshKey, ApiVmHostRegion, - ApiVmIpAssignment, ApiVmOsImage, ApiVmPayment, ApiVmStatus, ApiVmTemplate, CreateSshKey, - CreateVmRequest, VMPatchRequest, + AccountPatchRequest, ApiCustomTemplateParams, ApiCustomVmOrder, + ApiCustomVmRequest, ApiPaymentInfo, ApiPaymentMethod, ApiPrice, ApiTemplatesResponse, + ApiUserSshKey, ApiVmIpAssignment, ApiVmOsImage, ApiVmPayment, ApiVmStatus, + ApiVmTemplate, CreateSshKey, CreateVmRequest, VMPatchRequest, }; -use crate::exchange::ExchangeRateService; +use crate::exchange::{Currency, ExchangeRateService}; use crate::host::{get_host_client, FullVmInfo, TimeSeries, TimeSeriesData}; use crate::nip98::Nip98Auth; use crate::provisioner::{HostCapacityService, LNVpsProvisioner, PricingEngine}; use crate::settings::Settings; use crate::status::{VmState, VmStateCache}; use crate::worker::WorkJob; -use anyhow::{Context, Result}; +use anyhow::Result; use futures::future::join_all; -use lnvps_db::{IpRange, LNVpsDb, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate}; +use lnvps_db::{ + IpRange, LNVpsDb, PaymentMethod, VmCustomPricing, VmCustomPricingDisk, VmCustomTemplate, +}; use nostr::util::hex; use rocket::futures::{SinkExt, StreamExt}; use rocket::serde::json::Json; @@ -26,6 +28,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use ssh_key::PublicKey; use std::collections::{HashMap, HashSet}; +use std::str::FromStr; use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; @@ -48,7 +51,8 @@ pub fn routes() -> Vec { v1_patch_vm, v1_time_series, v1_custom_template_calc, - v1_create_custom_vm_order + v1_create_custom_vm_order, + v1_get_payment_methods ] } @@ -143,7 +147,7 @@ async fn vm_to_status( .map(|i| (i.id, i)) .collect(); - let template = ApiVmTemplate::from_vm(&db, &vm).await?; + let template = ApiVmTemplate::from_vm(db, &vm).await?; Ok(ApiVmStatus { id: vm.id, created: vm.created, @@ -309,7 +313,7 @@ async fn v1_list_vm_templates( }) .collect(); let custom_templates: Vec = - join_all(regions.iter().map(|(k, _)| db.list_custom_pricing(*k))) + join_all(regions.keys().map(|k| db.list_custom_pricing(*k))) .await .into_iter() .filter_map(|r| r.ok()) @@ -344,8 +348,7 @@ async fn v1_list_vm_templates( .into_iter() .filter_map(|t| { let region = regions.get(&t.region_id)?; - Some( - ApiCustomTemplateParams::from( + ApiCustomTemplateParams::from( &t, &custom_template_disks, region, @@ -353,8 +356,7 @@ async fn v1_list_vm_templates( max_memory, max_disk, ) - .ok()?, - ) + .ok() }) .collect(), ) @@ -376,7 +378,7 @@ async fn v1_custom_template_calc( let price = PricingEngine::get_custom_vm_cost_amount(db, 0, &template).await?; ApiData::ok(ApiPrice { - currency: price.currency.clone(), + currency: price.currency, amount: price.total(), }) } @@ -484,12 +486,13 @@ async fn v1_create_vm_order( /// Renew(Extend) a VM #[openapi(tag = "VM")] -#[get("/api/v1/vm//renew")] +#[get("/api/v1/vm//renew?")] async fn v1_renew_vm( auth: Nip98Auth, db: &State>, provisioner: &State>, id: u64, + method: Option<&str>, ) -> ApiResult { let pubkey = auth.event.pubkey.to_bytes(); let uid = db.upsert_user(&pubkey).await?; @@ -498,7 +501,14 @@ async fn v1_renew_vm( return ApiData::err("VM does not belong to you"); } - let rsp = provisioner.renew(id).await?; + let rsp = provisioner + .renew( + id, + method + .and_then(|m| PaymentMethod::from_str(m).ok()) + .unwrap_or(PaymentMethod::Lightning), + ) + .await?; ApiData::ok(rsp.into()) } @@ -596,6 +606,26 @@ async fn v1_time_series( ApiData::ok(client.get_time_series_data(&vm, TimeSeries::Hourly).await?) } +#[openapi(tag = "Payment")] +#[get("/api/v1/payment/methods")] +async fn v1_get_payment_methods(settings: &State) -> ApiResult> { + let mut ret = vec![ApiPaymentInfo { + name: ApiPaymentMethod::Lightning, + metadata: HashMap::new(), + currencies: vec![Currency::BTC], + }]; + #[cfg(feature = "revolut")] + if let Some(r) = &settings.revolut { + ret.push(ApiPaymentInfo { + name: ApiPaymentMethod::Revolut, + metadata: HashMap::from([("pubkey".to_string(), r.public_key.to_string())]), + currencies: vec![Currency::EUR, Currency::USD], + }) + } + + ApiData::ok(ret) +} + /// Get payment status (for polling) #[openapi(tag = "Payment")] #[get("/api/v1/payment/")] diff --git a/src/api/webhook.rs b/src/api/webhook.rs index 85a9ead..6f34718 100644 --- a/src/api/webhook.rs +++ b/src/api/webhook.rs @@ -6,23 +6,35 @@ use std::collections::HashMap; use std::sync::LazyLock; use tokio::sync::broadcast; -/// Messaging bridge for webhooks to other parts of the system (bitvora) +/// Messaging bridge for webhooks to other parts of the system (bitvora/revout) pub static WEBHOOK_BRIDGE: LazyLock = LazyLock::new(WebhookBridge::new); pub fn routes() -> Vec { - if cfg!(feature = "bitvora") { - routes![bitvora_webhook] - } else { - routes![] - } + let mut routes = vec![]; + + #[cfg(feature = "bitvora")] + routes.append(&mut routes![bitvora_webhook]); + + #[cfg(feature = "revolut")] + routes.append(&mut routes![revolut_webhook]); + + routes } +#[cfg(feature = "bitvora")] #[post("/api/v1/webhook/bitvora", data = "")] async fn bitvora_webhook(req: WebhookMessage) -> Status { WEBHOOK_BRIDGE.send(req); Status::Ok } +#[cfg(feature = "revolut")] +#[post("/api/v1/webhook/revolut", data = "")] +async fn revolut_webhook(req: WebhookMessage) -> Status { + WEBHOOK_BRIDGE.send(req); + Status::Ok +} + #[derive(Debug, Clone)] pub struct WebhookMessage { pub body: Vec, diff --git a/src/bin/api.rs b/src/bin/api.rs index 5338c9b..9bcaba8 100644 --- a/src/bin/api.rs +++ b/src/bin/api.rs @@ -6,8 +6,8 @@ use lnvps::api; use lnvps::cors::CORS; use lnvps::data_migration::run_data_migrations; use lnvps::exchange::{DefaultRateCache, ExchangeRateService}; -use lnvps::invoice::InvoiceHandler; use lnvps::lightning::get_node; +use lnvps::payments::listen_all_payments; use lnvps::settings::Settings; use lnvps::status::VmStateCache; use lnvps::worker::{WorkJob, Worker}; @@ -20,7 +20,6 @@ use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; -use tokio::time::sleep; #[derive(Parser)] #[clap(about, version, author)] @@ -37,7 +36,7 @@ struct Args { #[rocket::main] async fn main() -> Result<(), Error> { let log_level = std::env::var("RUST_LOG") - .unwrap_or_else(|_| "info".to_string()) // Default to "info" if not set + .unwrap_or_else(|_| "info".to_string()) .to_lowercase(); let max_level = match log_level.as_str() { @@ -47,7 +46,7 @@ async fn main() -> Result<(), Error> { "warn" => LevelFilter::Warn, "error" => LevelFilter::Error, "off" => LevelFilter::Off, - _ => LevelFilter::Info, + _ => LevelFilter::Debug, }; let args = Args::parse(); @@ -121,15 +120,10 @@ async fn main() -> Result<(), Error> { } } }); - let mut handler = InvoiceHandler::new(node.clone(), db.clone(), sender.clone()); - tokio::spawn(async move { - loop { - if let Err(e) = handler.listen().await { - error!("invoice-error: {}", e); - } - sleep(Duration::from_secs(5)).await; - } - }); + + // setup payment handlers + listen_all_payments(&settings, node.clone(), db.clone(), sender.clone())?; + // request work every 30s to check vm status let sender_clone = sender.clone(); tokio::spawn(async move { diff --git a/src/data_migration/dns.rs b/src/data_migration/dns.rs index 9c4d273..26c134c 100644 --- a/src/data_migration/dns.rs +++ b/src/data_migration/dns.rs @@ -28,17 +28,17 @@ impl DataMigration for DnsDataMigration { for vm in vms { let mut ips = db.list_vm_ip_assignments(vm.id).await?; - for mut ip in &mut ips { + for ip in &mut ips { let mut did_change = false; if ip.dns_forward.is_none() { - let rec = BasicRecord::forward(&ip)?; + let rec = BasicRecord::forward(ip)?; let r = dns.add_record(&rec).await?; ip.dns_forward = Some(r.name); ip.dns_forward_ref = r.id; did_change = true; } if ip.dns_reverse.is_none() { - let rec = BasicRecord::reverse_to_fwd(&ip)?; + let rec = BasicRecord::reverse_to_fwd(ip)?; let r = dns.add_record(&rec).await?; ip.dns_reverse = Some(r.value); ip.dns_reverse_ref = r.id; diff --git a/src/dns/cloudflare.rs b/src/dns/cloudflare.rs index 327f262..d42dd8f 100644 --- a/src/dns/cloudflare.rs +++ b/src/dns/cloudflare.rs @@ -14,8 +14,12 @@ pub struct Cloudflare { impl Cloudflare { pub fn new(token: &str, reverse_zone_id: &str, forward_zone_id: &str) -> Cloudflare { Self { - api: JsonApi::token("https://api.cloudflare.com", &format!("Bearer {}", token), false) - .unwrap(), + api: JsonApi::token( + "https://api.cloudflare.com", + &format!("Bearer {}", token), + false, + ) + .unwrap(), reverse_zone_id: reverse_zone_id.to_owned(), forward_zone_id: forward_zone_id.to_owned(), } diff --git a/src/exchange.rs b/src/exchange.rs index 834bb42..a009667 100644 --- a/src/exchange.rs +++ b/src/exchange.rs @@ -1,10 +1,10 @@ -use anyhow::{anyhow, ensure, Context, Error, Result}; +use anyhow::{anyhow, ensure, Result}; use lnvps_db::async_trait; use log::info; use rocket::serde::Deserialize; use schemars::JsonSchema; use serde::Serialize; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::str::FromStr; use std::sync::Arc; @@ -62,6 +62,12 @@ pub struct TickerRate(pub Ticker, pub f32); #[derive(Clone, Copy, Debug, PartialEq)] pub struct CurrencyAmount(pub Currency, pub f32); +impl CurrencyAmount { + pub fn from_u64(currency: Currency, amount: u64) -> Self { + CurrencyAmount(currency, amount as f32 / 100.0) + } +} + impl TickerRate { pub fn can_convert(&self, currency: Currency) -> bool { currency == self.0 .0 || currency == self.0 .1 @@ -99,7 +105,7 @@ pub fn alt_prices(rates: &Vec, source: CurrencyAmount) -> Vec Pin> + Send>>; +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FiatPaymentInfo { + pub external_id: String, + pub raw_data: String, +} diff --git a/src/fiat/revolut.rs b/src/fiat/revolut.rs new file mode 100644 index 0000000..d904fca --- /dev/null +++ b/src/fiat/revolut.rs @@ -0,0 +1,154 @@ +use crate::exchange::{Currency, CurrencyAmount}; +use crate::fiat::{FiatPaymentInfo, FiatPaymentService}; +use crate::json_api::JsonApi; +use crate::settings::RevolutConfig; +use anyhow::{bail, Result}; +use chrono::{DateTime, Utc}; +use reqwest::header::{HeaderMap, ACCEPT, AUTHORIZATION}; +use reqwest::{Client, Method}; +use serde::{Deserialize, Serialize}; +use std::future::Future; +use std::pin::Pin; + +pub struct RevolutApi { + api: JsonApi, +} + +impl RevolutApi { + pub fn new(config: RevolutConfig) -> Result { + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, format!("Bearer {}", config.token).parse()?); + headers.insert(ACCEPT, "application/json".parse()?); + headers.insert("Revolut-Api-Version", config.api_version.parse()?); + + let client = Client::builder().default_headers(headers).build()?; + Ok(Self { + api: JsonApi { + client, + base: config + .url + .unwrap_or("https://merchant.revolut.com".to_string()) + .parse()?, + }, + }) + } + + pub async fn list_webhooks(&self) -> Result> { + self.api.get("/api/1.0/webhooks").await + } + + pub async fn delete_webhook(&self, webhook_id: &str) -> Result<()> { + self.api + .req_status( + Method::DELETE, + &format!("/api/1.0/webhooks/{}", webhook_id), + (), + ) + .await?; + Ok(()) + } + + pub async fn create_webhook( + &self, + url: &str, + events: Vec, + ) -> Result { + self.api + .post( + "/api/1.0/webhooks", + CreateWebhookRequest { + url: url.to_string(), + events, + }, + ) + .await + } +} + +impl FiatPaymentService for RevolutApi { + fn create_order( + &self, + description: &str, + amount: CurrencyAmount, + ) -> Pin> + Send>> { + let api = self.api.clone(); + let desc = description.to_string(); + Box::pin(async move { + let rsp: CreateOrderResponse = api + .post( + "/api/orders", + CreateOrderRequest { + currency: amount.0.to_string(), + amount: match amount.0 { + Currency::BTC => bail!("Bitcoin amount not allowed for fiat payments"), + Currency::EUR => (amount.1 * 100.0).floor() as u64, + Currency::USD => (amount.1 * 100.0).floor() as u64, + }, + description: Some(desc), + }, + ) + .await?; + + Ok(FiatPaymentInfo { + raw_data: serde_json::to_string(&rsp)?, + external_id: rsp.id, + }) + }) + } +} + +#[derive(Clone, Serialize)] +pub struct CreateOrderRequest { + pub amount: u64, + pub currency: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct CreateOrderResponse { + pub id: String, + pub token: String, + pub state: PaymentState, + pub created_at: DateTime, + pub updated_at: DateTime, + pub description: Option, + pub amount: u64, + pub currency: String, + pub outstanding_amount: u64, + pub checkout_url: String, +} + +#[derive(Clone, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum PaymentState { + Pending, + Processing, + Authorised, + Completed, + Cancelled, + Failed, +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct RevolutWebhook { + pub id: String, + pub url: String, + pub events: Vec, + pub signing_secret: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum RevolutWebhookEvent { + OrderAuthorised, + OrderCompleted, + OrderCancelled, +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct CreateWebhookRequest { + pub url: String, + pub events: Vec, +} diff --git a/src/host/libvirt.rs b/src/host/libvirt.rs index 32d54fb..322eb3d 100644 --- a/src/host/libvirt.rs +++ b/src/host/libvirt.rs @@ -34,7 +34,7 @@ impl VmHostClient for LibVirt { todo!() } - async fn configure_vm(&self, vm: &Vm) -> anyhow::Result<()> { + async fn configure_vm(&self, vm: &FullVmInfo) -> anyhow::Result<()> { todo!() } diff --git a/src/json_api.rs b/src/json_api.rs index 73dd464..856088d 100644 --- a/src/json_api.rs +++ b/src/json_api.rs @@ -1,10 +1,11 @@ use anyhow::bail; use log::debug; -use reqwest::header::{HeaderMap, AUTHORIZATION}; +use reqwest::header::{HeaderMap, ACCEPT, AUTHORIZATION}; use reqwest::{Client, Method, Url}; use serde::de::DeserializeOwned; use serde::Serialize; +#[derive(Clone)] pub struct JsonApi { pub client: Client, pub base: Url, @@ -14,6 +15,7 @@ impl JsonApi { pub fn token(base: &str, token: &str, allow_invalid_certs: bool) -> anyhow::Result { let mut headers = HeaderMap::new(); headers.insert(AUTHORIZATION, token.parse()?); + headers.insert(ACCEPT, "application/json".parse()?); let client = Client::builder() .danger_accept_invalid_certs(allow_invalid_certs) @@ -59,7 +61,6 @@ impl JsonApi { .client .request(method.clone(), self.base.join(path)?) .header("Content-Type", "application/json") - .header("Accept", "application/json") .body(body) .send() .await?; @@ -73,4 +74,31 @@ impl JsonApi { bail!("{} {}: {}: {}", method, path, status, &text); } } + + /// Make a request and only return the status code + pub async fn req_status( + &self, + method: Method, + path: &str, + body: R, + ) -> anyhow::Result { + let body = serde_json::to_string(&body)?; + debug!(">> {} {}: {}", method.clone(), path, &body); + let rsp = self + .client + .request(method.clone(), self.base.join(path)?) + .header("Content-Type", "application/json") + .body(body) + .send() + .await?; + let status = rsp.status(); + let text = rsp.text().await?; + #[cfg(debug_assertions)] + debug!("<< {}", text); + if status.is_success() { + Ok(status.as_u16()) + } else { + bail!("{} {}: {}: {}", method, path, status, &text); + } + } } diff --git a/src/lib.rs b/src/lib.rs index a1ed9af..812a588 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,14 @@ pub mod api; pub mod cors; +pub mod data_migration; pub mod dns; pub mod exchange; +pub mod fiat; pub mod host; -pub mod invoice; pub mod json_api; pub mod lightning; pub mod nip98; +pub mod payments; pub mod provisioner; pub mod router; pub mod settings; @@ -14,6 +16,6 @@ pub mod settings; pub mod ssh_client; pub mod status; pub mod worker; -pub mod data_migration; + #[cfg(test)] pub mod mocks; diff --git a/src/lightning/bitvora.rs b/src/lightning/bitvora.rs index a988c31..2a538e8 100644 --- a/src/lightning/bitvora.rs +++ b/src/lightning/bitvora.rs @@ -17,7 +17,7 @@ impl BitvoraNode { pub fn new(api_token: &str, webhook_secret: &str) -> Self { let auth = format!("Bearer {}", api_token); Self { - api: JsonApi::token("https://api.bitvora.com/", &auth).unwrap(), + api: JsonApi::token("https://api.bitvora.com/", &auth, false).unwrap(), webhook_secret: webhook_secret.to_string(), } } diff --git a/src/lightning/lnd.rs b/src/lightning/lnd.rs index 366b54e..2df7e0b 100644 --- a/src/lightning/lnd.rs +++ b/src/lightning/lnd.rs @@ -78,7 +78,6 @@ impl LightningNode for LndNode { Ok(m) => { if m.state == InvoiceState::Settled as i32 { InvoiceUpdate::Settled { - settle_index: m.settle_index, payment_hash: hex::encode(m.r_hash), } } else { diff --git a/src/lightning/mod.rs b/src/lightning/mod.rs index b56b69d..2f81f16 100644 --- a/src/lightning/mod.rs +++ b/src/lightning/mod.rs @@ -40,7 +40,6 @@ pub enum InvoiceUpdate { Error(String), Settled { payment_hash: String, - settle_index: u64, }, } diff --git a/src/mocks.rs b/src/mocks.rs index 7b4abff..9f93308 100644 --- a/src/mocks.rs +++ b/src/mocks.rs @@ -518,11 +518,18 @@ impl LNVpsDb for MockDb { .clone()) } + async fn get_vm_payment_by_ext_id(&self, id: &str) -> anyhow::Result { + let p = self.payments.lock().await; + Ok(p.iter() + .find(|p| p.external_id == Some(id.to_string())) + .context("no vm_payment")? + .clone()) + } + async fn update_vm_payment(&self, vm_payment: &VmPayment) -> anyhow::Result<()> { let mut p = self.payments.lock().await; if let Some(p) = p.iter_mut().find(|p| p.id == *vm_payment.id) { p.is_paid = vm_payment.is_paid.clone(); - p.settle_index = vm_payment.settle_index.clone(); } Ok(()) } @@ -539,7 +546,8 @@ impl LNVpsDb for MockDb { async fn last_paid_invoice(&self) -> anyhow::Result> { let p = self.payments.lock().await; Ok(p.iter() - .max_by(|a, b| a.settle_index.cmp(&b.settle_index)) + .filter(|p| p.is_paid) + .max_by(|a, b| a.created.cmp(&b.created)) .map(|v| v.clone())) } diff --git a/src/nip98.rs b/src/nip98.rs index c9959c1..c2da012 100644 --- a/src/nip98.rs +++ b/src/nip98.rs @@ -98,7 +98,7 @@ impl<'r> FromRequest<'r> for Nip98Auth { } let auth = Nip98Auth::from_base64(&auth[6..]).unwrap(); match auth.check( - request.uri().to_string().as_str(), + request.uri().path().to_string().as_str(), request.method().as_str(), ) { Ok(_) => Outcome::Success(auth), diff --git a/src/invoice.rs b/src/payments/invoice.rs similarity index 76% rename from src/invoice.rs rename to src/payments/invoice.rs index 24db1f3..09c47c8 100644 --- a/src/invoice.rs +++ b/src/payments/invoice.rs @@ -8,13 +8,13 @@ use rocket::futures::StreamExt; use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; -pub struct InvoiceHandler { +pub struct NodeInvoiceHandler { node: Arc, db: Arc, tx: UnboundedSender, } -impl InvoiceHandler { +impl NodeInvoiceHandler { pub fn new( node: Arc, db: Arc, @@ -23,9 +23,8 @@ impl InvoiceHandler { Self { node, tx, db } } - async fn mark_paid(&self, settle_index: u64, id: &Vec) -> Result<()> { - let mut p = self.db.get_vm_payment(id).await?; - p.settle_index = Some(settle_index); + async fn mark_paid(&self, id: &Vec) -> Result<()> { + let p = self.db.get_vm_payment(id).await?; self.db.vm_payment_paid(&p).await?; info!("VM payment {} for {}, paid", hex::encode(p.id), p.vm_id); @@ -47,12 +46,9 @@ impl InvoiceHandler { let mut handler = self.node.subscribe_invoices(from_ph).await?; while let Some(msg) = handler.next().await { match msg { - InvoiceUpdate::Settled { - payment_hash, - settle_index, - } => { + InvoiceUpdate::Settled { payment_hash } => { let r_hash = hex::decode(payment_hash)?; - if let Err(e) = self.mark_paid(settle_index, &r_hash).await { + if let Err(e) = self.mark_paid(&r_hash).await { error!("{}", e); } } diff --git a/src/payments/mod.rs b/src/payments/mod.rs new file mode 100644 index 0000000..3749b8a --- /dev/null +++ b/src/payments/mod.rs @@ -0,0 +1,56 @@ +use crate::lightning::LightningNode; +use crate::payments::invoice::NodeInvoiceHandler; +use crate::settings::Settings; +use crate::worker::WorkJob; +use anyhow::Result; +use lnvps_db::LNVpsDb; +use log::error; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::mpsc::UnboundedSender; +use tokio::time::sleep; + +mod invoice; +#[cfg(feature = "revolut")] +mod revolut; + +pub fn listen_all_payments( + settings: &Settings, + node: Arc, + db: Arc, + sender: UnboundedSender, +) -> Result<()> { + let mut handler = NodeInvoiceHandler::new(node.clone(), db.clone(), sender.clone()); + tokio::spawn(async move { + loop { + if let Err(e) = handler.listen().await { + error!("invoice-error: {}", e); + } + sleep(Duration::from_secs(30)).await; + } + }); + + #[cfg(feature = "revolut")] + { + + use crate::payments::revolut::RevolutPaymentHandler; + if let Some(r) = &settings.revolut { + let mut handler = RevolutPaymentHandler::new( + r.clone(), + &settings.public_url, + db.clone(), + sender.clone(), + )?; + tokio::spawn(async move { + loop { + if let Err(e) = handler.listen().await { + error!("revolut-error: {}", e); + } + sleep(Duration::from_secs(30)).await; + } + }); + } + } + + Ok(()) +} diff --git a/src/payments/revolut.rs b/src/payments/revolut.rs new file mode 100644 index 0000000..d19f967 --- /dev/null +++ b/src/payments/revolut.rs @@ -0,0 +1,129 @@ +use crate::api::{WebhookMessage, WEBHOOK_BRIDGE}; +use crate::fiat::{RevolutApi, RevolutWebhookEvent}; +use crate::settings::RevolutConfig; +use crate::worker::WorkJob; +use anyhow::{anyhow, bail, Context, Result}; +use hmac::{Hmac, Mac}; +use lnvps_db::LNVpsDb; +use log::{error, info, warn}; +use reqwest::Url; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::mpsc::UnboundedSender; + +pub struct RevolutPaymentHandler { + api: RevolutApi, + db: Arc, + sender: UnboundedSender, + public_url: String, +} + +impl RevolutPaymentHandler { + pub fn new( + settings: RevolutConfig, + public_url: &str, + db: Arc, + sender: UnboundedSender, + ) -> Result { + Ok(Self { + api: RevolutApi::new(settings)?, + public_url: public_url.to_string(), + db, + sender, + }) + } + + pub async fn listen(&mut self) -> Result<()> { + let this_webhook = Url::parse(&self.public_url)?.join("/api/v1/webhook/revolut")?; + let webhooks = self.api.list_webhooks().await?; + for wh in webhooks { + info!("Deleting old webhook: {} {}", wh.id, wh.url); + self.api.delete_webhook(&wh.id).await? + } + info!("Setting up webhook for '{}'", this_webhook); + let wh = self + .api + .create_webhook( + this_webhook.as_str(), + vec![ + RevolutWebhookEvent::OrderCompleted, + RevolutWebhookEvent::OrderAuthorised, + ], + ) + .await?; + + let secret = wh.signing_secret.context("Signing secret is missing")?; + // listen to events + let mut listenr = WEBHOOK_BRIDGE.listen(); + while let Ok(m) = listenr.recv().await { + let body: RevolutWebhook = serde_json::from_slice(m.body.as_slice())?; + info!("Received webhook {:?}", body); + if let Err(e) = verify_webhook(&secret, &m) { + error!("Signature verification failed: {}", e); + continue; + } + + if let RevolutWebhookEvent::OrderCompleted = body.event { + if let Err(e) = self.try_complete_payment(&body.order_id).await { + error!("Failed to complete order: {}", e); + } + } + } + Ok(()) + } + + async fn try_complete_payment(&self, ext_id: &str) -> Result<()> { + let p = self.db.get_vm_payment_by_ext_id(ext_id).await?; + self.db.vm_payment_paid(&p).await?; + self.sender.send(WorkJob::CheckVm { vm_id: p.vm_id })?; + info!("VM payment {} for {}, paid", hex::encode(p.id), p.vm_id); + Ok(()) + } +} + +type HmacSha256 = Hmac; +fn verify_webhook(secret: &str, msg: &WebhookMessage) -> Result<()> { + let sig = msg + .headers + .get("revolut-signature") + .ok_or_else(|| anyhow!("Missing Revolut-Signature header"))?; + let timestamp = msg + .headers + .get("revolut-request-timestamp") + .ok_or_else(|| anyhow!("Missing Revolut-Request-Timestamp header"))?; + + // check if any signatures match + for sig in sig.split(",") { + let mut sig_split = sig.split("="); + let (version, code) = ( + sig_split.next().context("Invalid signature format")?, + sig_split.next().context("Invalid signature format")?, + ); + let mut mac = HmacSha256::new_from_slice(secret.as_bytes())?; + mac.update(version.as_bytes()); + mac.update(b"."); + mac.update(timestamp.as_bytes()); + mac.update(b"."); + mac.update(msg.body.as_slice()); + let result = mac.finalize().into_bytes(); + + if hex::encode(result) == code { + return Ok(()); + } else { + warn!( + "Invalid signature found {} != {}", + code, + hex::encode(result) + ); + } + } + + bail!("No valid signature found!"); +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +struct RevolutWebhook { + pub event: RevolutWebhookEvent, + pub order_id: String, + pub merchant_order_ext_ref: Option, +} diff --git a/src/provisioner/lnvps.rs b/src/provisioner/lnvps.rs index e58b9be..2baa6b8 100644 --- a/src/provisioner/lnvps.rs +++ b/src/provisioner/lnvps.rs @@ -1,5 +1,6 @@ use crate::dns::{BasicRecord, DnsServer}; -use crate::exchange::{ExchangeRateService, Ticker}; +use crate::exchange::{Currency, CurrencyAmount, ExchangeRateService}; +use crate::fiat::FiatPaymentService; use crate::host::{get_host_client, FullVmInfo}; use crate::lightning::{AddInvoiceRequest, LightningNode}; use crate::provisioner::{ @@ -8,8 +9,11 @@ use crate::provisioner::{ use crate::router::{ArpEntry, Router}; use crate::settings::{NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings}; use anyhow::{bail, ensure, Context, Result}; -use chrono::{Days, Months, Utc}; -use lnvps_db::{DiskType, LNVpsDb, Vm, VmCostPlanIntervalType, VmCustomTemplate, VmIpAssignment, VmPayment}; +use chrono::Utc; +use lnvps_db::{ + LNVpsDb, PaymentMethod, Vm, VmCustomTemplate, VmIpAssignment, + VmPayment, +}; use log::{info, warn}; use nostr::util::hex; use std::ops::Add; @@ -28,6 +32,7 @@ pub struct LNVpsProvisioner { router: Option>, dns: Option>, + revolut: Option>, network_policy: NetworkPolicy, provisioner_config: ProvisionerConfig, @@ -46,6 +51,7 @@ impl LNVpsProvisioner { rates, router: settings.get_router().expect("router config"), dns: settings.get_dns().expect("dns config"), + revolut: settings.get_revolut().expect("revolut config"), network_policy: settings.network_policy, provisioner_config: settings.provisioner, read_only: settings.read_only, @@ -255,7 +261,9 @@ impl LNVpsProvisioner { // TODO: cache capacity somewhere let cap = HostCapacityService::new(self.db.clone()); - let host = cap.get_host_for_template(template.region_id, &template).await?; + let host = cap + .get_host_for_template(template.region_id, &template) + .await?; let pick_disk = if let Some(hd) = host.disks.first() { hd @@ -308,7 +316,9 @@ impl LNVpsProvisioner { // TODO: cache capacity somewhere let cap = HostCapacityService::new(self.db.clone()); - let host = cap.get_host_for_template(pricing.region_id, &template).await?; + let host = cap + .get_host_for_template(pricing.region_id, &template) + .await?; let pick_disk = if let Some(hd) = host.disks.first() { hd @@ -345,40 +355,83 @@ impl LNVpsProvisioner { } /// Create a renewal payment - pub async fn renew(&self, vm_id: u64) -> Result { + pub async fn renew(&self, vm_id: u64, method: PaymentMethod) -> Result { let pe = PricingEngine::new(self.db.clone(), self.rates.clone()); - let price = pe.get_vm_cost(vm_id).await?; + let price = pe.get_vm_cost(vm_id, method).await?; match price { CostResult::Existing(p) => Ok(p), CostResult::New { - msats, + amount, + currency, time_value, new_expiry, rate, } => { - const INVOICE_EXPIRE: u64 = 600; - info!("Creating invoice for {vm_id} for {} sats", msats / 1000); - let invoice = self - .node - .add_invoice(AddInvoiceRequest { - memo: Some(format!("VM renewal {vm_id} to {new_expiry}")), - amount: msats, - expire: Some(INVOICE_EXPIRE as u32), - }) - .await?; - let vm_payment = VmPayment { - id: hex::decode(invoice.payment_hash)?, - vm_id, - created: Utc::now(), - expires: Utc::now().add(Duration::from_secs(INVOICE_EXPIRE)), - amount: msats, - invoice: invoice.pr, - time_value, - is_paid: false, - rate, - settle_index: None, + let desc = format!("VM renewal {vm_id} to {new_expiry}"); + let vm_payment = match method { + PaymentMethod::Lightning => { + ensure!( + currency == Currency::BTC, + "Cannot create invoices for non-BTC currency" + ); + const INVOICE_EXPIRE: u64 = 600; + info!("Creating invoice for {vm_id} for {} sats", amount / 1000); + let invoice = self + .node + .add_invoice(AddInvoiceRequest { + memo: Some(desc), + amount, + expire: Some(INVOICE_EXPIRE as u32), + }) + .await?; + VmPayment { + id: hex::decode(invoice.payment_hash)?, + vm_id, + created: Utc::now(), + expires: Utc::now().add(Duration::from_secs(INVOICE_EXPIRE)), + amount, + currency: currency.to_string(), + payment_method: method, + time_value, + is_paid: false, + rate, + external_data: invoice.pr, + external_id: None, + } + } + PaymentMethod::Revolut => { + let rev = if let Some(r) = &self.revolut { + r + } else { + bail!("Revolut not configured") + }; + ensure!( + currency != Currency::BTC, + "Cannot create revolut orders for BTC currency" + ); + let order = rev + .create_order(&desc, CurrencyAmount::from_u64(currency, amount)) + .await?; + let new_id: [u8; 32] = rand::random(); + VmPayment { + id: new_id.to_vec(), + vm_id, + created: Utc::now(), + expires: Utc::now().add(Duration::from_secs(3600)), + amount, + currency: currency.to_string(), + payment_method: method, + time_value, + is_paid: false, + rate, + external_data: order.raw_data, + external_id: Some(order.external_id), + } + } + PaymentMethod::Paypal => todo!(), }; + self.db.insert_vm_payment(&vm_payment).await?; Ok(vm_payment) @@ -443,6 +496,7 @@ mod tests { Settings { listen: None, db: "".to_string(), + public_url: "http://localhost:8000".to_string(), lightning: LightningConfig::LND { url: "".to_string(), cert: Default::default(), @@ -480,6 +534,7 @@ mod tests { reverse_zone_id: "456".to_string(), }), nostr: None, + revolut: None, } } diff --git a/src/provisioner/mod.rs b/src/provisioner/mod.rs index ba73281..24ae37c 100644 --- a/src/provisioner/mod.rs +++ b/src/provisioner/mod.rs @@ -59,4 +59,4 @@ impl Template for VmCustomTemplate { fn disk_interface(&self) -> DiskInterface { self.disk_interface } -} \ No newline at end of file +} diff --git a/src/provisioner/pricing.rs b/src/provisioner/pricing.rs index 1e87e1e..6ac04a3 100644 --- a/src/provisioner/pricing.rs +++ b/src/provisioner/pricing.rs @@ -1,8 +1,10 @@ -use crate::exchange::{Currency, ExchangeRateService, Ticker}; -use anyhow::{bail, Context, Result}; +use crate::exchange::{Currency, CurrencyAmount, ExchangeRateService, Ticker, TickerRate}; +use anyhow::{bail, Result}; use chrono::{DateTime, Days, Months, TimeDelta, Utc}; use ipnetwork::IpNetwork; -use lnvps_db::{LNVpsDb, Vm, VmCostPlan, VmCostPlanIntervalType, VmCustomTemplate, VmPayment}; +use lnvps_db::{ + LNVpsDb, PaymentMethod, Vm, VmCostPlan, VmCostPlanIntervalType, VmCustomTemplate, VmPayment, +}; use log::info; use std::ops::Add; use std::str::FromStr; @@ -28,22 +30,22 @@ impl PricingEngine { } /// Get VM cost (for renewal) - pub async fn get_vm_cost(&self, vm_id: u64) -> Result { + pub async fn get_vm_cost(&self, vm_id: u64, method: PaymentMethod) -> Result { let vm = self.db.get_vm(vm_id).await?; // Reuse existing payment until expired let payments = self.db.list_vm_payment(vm.id).await?; if let Some(px) = payments .into_iter() - .find(|p| p.expires > Utc::now() && !p.is_paid) + .find(|p| p.expires > Utc::now() && !p.is_paid && p.payment_method == method) { return Ok(CostResult::Existing(px)); } if vm.template_id.is_some() { - Ok(self.get_template_vm_cost(&vm).await?) + Ok(self.get_template_vm_cost(&vm, method).await?) } else { - Ok(self.get_custom_vm_cost(&vm).await?) + Ok(self.get_custom_vm_cost(&vm, method).await?) } } @@ -101,7 +103,7 @@ impl PricingEngine { }) } - async fn get_custom_vm_cost(&self, vm: &Vm) -> Result { + async fn get_custom_vm_cost(&self, vm: &Vm, method: PaymentMethod) -> Result { let template_id = if let Some(i) = vm.custom_template_id { i } else { @@ -114,26 +116,32 @@ impl PricingEngine { // custom templates are always 1-month intervals let time_value = (vm.expires.add(Months::new(1)) - vm.expires).num_seconds() as u64; - let (cost_msats, rate) = self.get_msats_amount(price.currency, price.total()).await?; + let (currency, amount, rate) = self + .get_amount_and_rate(CurrencyAmount(price.currency, price.total()), method) + .await?; Ok(CostResult::New { - msats: cost_msats, + amount, + currency, rate, time_value, new_expiry: vm.expires.add(TimeDelta::seconds(time_value as i64)), }) } - async fn get_msats_amount(&self, currency: Currency, amount: f32) -> Result<(u64, f32)> { + async fn get_ticker(&self, currency: Currency) -> Result { let ticker = Ticker(Currency::BTC, currency); - let rate = if let Some(r) = self.rates.get_rate(ticker).await { - r + if let Some(r) = self.rates.get_rate(ticker).await { + Ok(TickerRate(ticker, r)) } else { bail!("No exchange rate found") - }; + } + } - let cost_btc = amount / rate; + async fn get_msats_amount(&self, amount: CurrencyAmount) -> Result<(u64, f32)> { + let rate = self.get_ticker(amount.0).await?; + let cost_btc = amount.1 / rate.1; let cost_msats = (cost_btc as f64 * Self::BTC_SATS) as u64 * 1000; - Ok((cost_msats, rate)) + Ok((cost_msats, rate.1)) } fn next_template_expire(vm: &Vm, cost_plan: &VmCostPlan) -> u64 { @@ -150,7 +158,7 @@ impl PricingEngine { (next_expire - vm.expires).num_seconds() as u64 } - async fn get_template_vm_cost(&self, vm: &Vm) -> Result { + async fn get_template_vm_cost(&self, vm: &Vm, method: PaymentMethod) -> Result { let template_id = if let Some(i) = vm.template_id { i } else { @@ -159,20 +167,36 @@ impl PricingEngine { let template = self.db.get_vm_template(template_id).await?; let cost_plan = self.db.get_cost_plan(template.cost_plan_id).await?; - let (cost_msats, rate) = self - .get_msats_amount( - cost_plan.currency.parse().expect("Invalid currency"), - cost_plan.amount, - ) + let currency = cost_plan.currency.parse().expect("Invalid currency"); + let (currency, amount, rate) = self + .get_amount_and_rate(CurrencyAmount(currency, cost_plan.amount), method) .await?; - let time_value = Self::next_template_expire(&vm, &cost_plan); + let time_value = Self::next_template_expire(vm, &cost_plan); Ok(CostResult::New { - msats: cost_msats, + amount, + currency, rate, time_value, new_expiry: vm.expires.add(TimeDelta::seconds(time_value as i64)), }) } + + async fn get_amount_and_rate( + &self, + list_price: CurrencyAmount, + method: PaymentMethod, + ) -> Result<(Currency, u64, f32)> { + Ok(match (list_price.0, method) { + (c, PaymentMethod::Lightning) if c != Currency::BTC => { + let new_price = self.get_msats_amount(list_price).await?; + (Currency::BTC, new_price.0, new_price.1) + } + (cur, PaymentMethod::Revolut) if cur != Currency::BTC => { + (cur, (list_price.1 * 100.0).ceil() as u64, 1.0) + } + (c, m) => bail!("Cannot create payment for method {} and currency {}", m, c), + }) + } } #[derive(Clone)] @@ -181,8 +205,10 @@ pub enum CostResult { Existing(VmPayment), /// A new payment can be created with the specified amount New { - /// The cost in milli-sats - msats: u64, + /// The cost + amount: u64, + /// Currency + currency: Currency, /// The exchange rate used to calculate the price rate: f32, /// The time to extend the vm expiry in seconds @@ -292,14 +318,14 @@ mod tests { let db: Arc = Arc::new(db); let pe = PricingEngine::new(db.clone(), rates); - let price = pe.get_vm_cost(1).await?; + let price = pe.get_vm_cost(1, PaymentMethod::Lightning).await?; let plan = MockDb::mock_cost_plan(); match price { - CostResult::Existing(_) => bail!("??"), - CostResult::New { msats, .. } => { + CostResult::New { amount, .. } => { let expect_price = (plan.amount / MOCK_RATE * 1.0e11) as u64; - assert_eq!(expect_price, msats); + assert_eq!(expect_price, amount); } + _ => bail!("??"), } Ok(()) diff --git a/src/settings.rs b/src/settings.rs index 04540e2..440a8af 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,5 +1,6 @@ use crate::dns::DnsServer; use crate::exchange::ExchangeRateService; +use crate::fiat::FiatPaymentService; use crate::lightning::LightningNode; use crate::provisioner::LNVpsProvisioner; use crate::router::Router; @@ -12,9 +13,15 @@ use std::sync::Arc; #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] pub struct Settings { + /// Listen address for http server pub listen: Option, + + /// MYSQL connection string pub db: String, + /// Public URL mapping to this service + pub public_url: String, + /// Lightning node config for creating LN payments pub lightning: LightningConfig, @@ -42,6 +49,9 @@ pub struct Settings { /// Nostr config for sending DMs pub nostr: Option, + + /// Config for accepting revolut payments + pub revolut: Option, } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -168,6 +178,15 @@ pub struct QemuConfig { pub kvm: bool, } +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "kebab-case")] +pub struct RevolutConfig { + pub url: Option, + pub api_version: String, + pub token: String, + pub public_key: String, +} + impl Settings { pub fn get_provisioner( &self, @@ -226,4 +245,12 @@ impl Settings { } } } + + pub fn get_revolut(&self) -> Result>> { + match &self.revolut { + #[cfg(feature = "revolut")] + Some(c) => Ok(Some(Arc::new(crate::fiat::RevolutApi::new(c.clone())?))), + _ => Ok(None), + } + } }