diff --git a/Cargo.lock b/Cargo.lock index a06e2d5..4fb9635 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -947,18 +947,18 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -1072,18 +1072,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.189" +version = "1.0.195" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" +checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.189" +version = "1.0.195" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" +checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" dependencies = [ "proc-macro2", "quote", @@ -1102,9 +1102,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.107" +version = "1.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +checksum = "176e46fa42316f18edd598015a5166857fc835ec732f5215eac6b7bdbf0a84f4" dependencies = [ "itoa", "ryu", @@ -1221,9 +1221,9 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" -version = "2.0.38" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", @@ -1612,6 +1612,7 @@ dependencies = [ "regex", "serde", "serde_ini", + "serde_json", "serde_with", "tokio", "url", diff --git a/Cargo.toml b/Cargo.toml index 5f1a1c2..1d6c958 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ serde_with = "3.5.0" urlencoding = "2.1.3" log = "0.4.20" env_logger = "0.11.0" +serde_json = "1.0.111" [build-dependencies] chrono = "0.4.31" diff --git a/src/event.rs b/src/event.rs index 23dea6a..0843755 100644 --- a/src/event.rs +++ b/src/event.rs @@ -9,23 +9,24 @@ type WgLink = String; type WgPeer = String; #[derive(Clone, Debug, Serialize, Deserialize)] -struct EventInfo { - link: WgLink, - cidr: IpInet, +pub struct EventInfo { + pub link: WgLink, + pub cidr: IpInet, } #[derive(Clone, Debug, Serialize, Deserialize)] -enum Event { +pub enum Event { Add(EventInfo), Del(EventInfo), } -struct EventParseError(T); +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct EventParseError(T); impl FromStr for Event { - type Err = EventParseError<&'static str>; + type Err = EventParseError; fn from_str(s: &str) -> Result { - let re = regex::Regex::new(r"^(Deleted )?\d+: ([a-z0-9]+)\s+inet ((\d+\.?){4}(/\d+)?).*$") + let re = regex::Regex::new(r"^(Deleted )?\d+: ([a-z0-9]+)\s+inet ((\d+\.?){4}(/\d+)?).*\n?$") .unwrap(); if let Some(captures) = re.captures(s) { @@ -40,28 +41,18 @@ impl FromStr for Event { } } else { Err(Self::Err { - 0: "Line couldn't be parsed as event", + 0: format!("'{}' couldn't be parsed as event", s), }) } + } +} - //if re_send_add.is_match(s) { - // send_add = true; - //} else if re_send_del.is_match(s) { - // send_del = true; - //} - //if send_add | send_del { - // cidr_ = re_extract_ip.captures(s).unwrap().get(1).unwrap().as_str(); - // public_key = send_daemon_get_public_key(&iface).await; - //} else { - // return Err; - //} - //if send_add { - // let request = hyper::Request::builder() - // .method(hyper::Method::PATCH) - // .uri(format!("http://{peer}:port/wireguard/{iface}/peer/{public_key}/address", peer=cidr_, iface=iface, public_key=public_key)) - // .body(hyper::Body::empty()).unwrap(); - // hyper::Client::new().request(request).await.unwrap(); - //} +impl Event { + pub fn inner(&self) -> &EventInfo { + match self { + Self::Add(e) => e, + Self::Del(e) => e, + } } } diff --git a/src/helpers.rs b/src/helpers.rs index c77e684..40e38cf 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -1,5 +1,7 @@ use std::process::Output; use tokio::process::Command; +use warp::{Filter, Rejection}; +use std::str::FromStr; pub async fn command_output(command: &mut Command) -> std::io::Result { log::debug!("Shell command {:?}", command); @@ -7,3 +9,24 @@ pub async fn command_output(command: &mut Command) -> std::io::Result { log::debug!("Output: {:?}", output); output } + +async fn param_handler(s: String) -> Result +where + ::Err: std::fmt::Debug, +{ + let new_s = match urlencoding::decode(&s) { + Ok(v) => v, + Err(_e) => return Err(warp::reject::not_found()), + }; + let t = match T::from_str(&new_s) { + Ok(v) => v, + Err(_e) => return Err(warp::reject::not_found()), + }; + Ok(t) +} +pub fn param() -> impl Filter + Copy +where + ::Err: std::fmt::Debug, +{ + warp::path::param::().and_then(param_handler) +} diff --git a/src/main.rs b/src/main.rs index 67ba3cb..072b779 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,13 @@ use clap::Parser; use helpers::command_output; -use std::str::FromStr; +use std::{str::FromStr, collections::HashMap}; use tokio::io::{self, AsyncBufReadExt, BufReader}; use warp::Filter; use wg::Peer; mod event; mod helpers; +mod model; mod wg; type WgLink = String; @@ -24,6 +25,9 @@ struct CliArguments { #[derive(Debug)] struct RejectCommandFailedToExecute; impl warp::reject::Reject for RejectCommandFailedToExecute {} +#[derive(Debug)] +struct RejectBadRequest; +impl warp::reject::Reject for RejectBadRequest {} #[derive(Debug)] struct CommandError { @@ -56,6 +60,64 @@ fn split_allowed_ips_text(text: &str, peer: WgPeer) -> (WgPeer, Vec Result { + let peer = message.peer; + let event = message.event; + log::debug!("Received event {:?} from peer {:?}", &event, &peer); + let mut new_peer: Option<&wg::Peer> = None; + for peer_ in config.peers.iter() { + if peer_.public_key == peer { + new_peer = Some(peer_); + break; + } + } + if new_peer.is_none() { + return Err(warp::reject::custom(RejectBadRequest)); + } + let new_peer = new_peer.unwrap(); + let iter = new_peer.allowed_ips.iter(); + let e_info = match event.clone() { + event::Event::Add(e) => e, + event::Event::Del(e) => e, + }; + let iter = match event { + event::Event::Add(_) => iter.chain(std::iter::once(&e_info.cidr)).collect::>(), + _ => iter.collect::>(), + }; + let ips_str = + iter.iter() + .map(|x| x.to_string()) + .fold("".to_string(), |acc, x| { + let mut s = String::from(acc); + s.push_str(&x); + s.push_str(","); + s + }); + let ips_str = &ips_str[0..ips_str.len() - 1]; + + let mut command = tokio::process::Command::new("wg"); + command + .arg("set") + .arg(&e_info.link) + .arg("peer") + .arg(&peer) + .arg("allowed-ips") + .arg(ips_str); + let output = match command_output(&mut command).await { + Ok(v) => v, + Err(_) => return Err(warp::reject::custom(RejectCommandFailedToExecute)), + }; + + if output.status.success() { + Ok("") + } else { + Err(warp::reject::custom(CommandError::from(output))) + } +} + async fn wg_add_address( link: WgLink, peer: WgPeer, @@ -163,18 +225,14 @@ async fn wg_del_address( async fn send_daemon_get_public_key(iface: &WgLink) -> WgPeer { let mut command = tokio::process::Command::new("wg"); command.arg("show").arg(&iface).arg("public-key"); - println!("command = {:?}", command); - let output = command.output().await.unwrap(); + let output = helpers::command_output(&mut command).await.unwrap(); std::str::from_utf8(output.stdout.as_ref()) .unwrap() .trim_end_matches("\n") .to_string() } -async fn send_daemon(iface: WgLink, port: u16, peers: Vec) -> () { - let re_send_add = regex::Regex::new(r"^\d+:.*\n$").unwrap(); - let re_send_del = regex::Regex::new(r"^Deleted \d+:.*\n$").unwrap(); - let re_extract_ip = regex::Regex::new(r"^.*inet ((\d+\.?){4}(/\d+)?).*\n$").unwrap(); +async fn event_generator(iface: WgLink, tx: tokio::sync::broadcast::Sender) -> () { let mut command = tokio::process::Command::new("ip"); command.arg("monitor").arg("address").arg("dev").arg(&iface); command.stdout(std::process::Stdio::piped()); @@ -191,11 +249,6 @@ async fn send_daemon(iface: WgLink, port: u16, peers: Vec) -> () { let mut enable_stdout = true; let mut enable_stderr = true; loop { - let mut send_add = false; - let mut send_del = false; - - let cidr_: String; - let public_key; let b: String; if enable_stdout & enable_stderr { @@ -243,59 +296,74 @@ async fn send_daemon(iface: WgLink, port: u16, peers: Vec) -> () { break; } - if re_send_add.is_match(&b) { - send_add = true; - } else if re_send_del.is_match(&b) { - send_del = true; + let e = match event::Event::from_str(&b) { + Ok(e) => e, + Err(_) => { continue; }, + }; + tx.send(e).unwrap(); + } +} + +async fn event_sender( + config: wg::Wg, + addr: std::net::SocketAddr, + rx: tokio::sync::broadcast::Receiver, +) -> () { + let mut delayed: HashMap> = HashMap::new(); + let mut rx = rx; + for peer in config.peers.iter() { + delayed.insert(peer.clone(), vec![]); + } + let port = addr.port(); + let client = hyper::Client::new(); + loop { + let e: event::Event = rx.recv().await.unwrap(); + let public_key = send_daemon_get_public_key(&e.inner().link).await; + + let mut join_set = tokio::task::JoinSet::new(); + join_set.spawn(async {tokio::time::sleep(std::time::Duration::from_millis(500)).await; 0}); + for (i, peer) in config.peers.iter().enumerate() { + //let u = format!( + // "http://{peer}:{port}/wireguard/{iface}/peer/{public_key}/address", + // peer=peer.allowed_ips.get(0).unwrap(), + // port=port, + // iface=encoded_iface, + // public_key=encoded_public_key, + //); + //let request = hyper::Request::builder() + // .method(&method) + // .uri(&u) + // .body(hyper::Body::from(encoded_cidr.to_string())) + // .unwrap(); + let u = format!( + "http://{peer}:{port}/ipc", + peer=peer.allowed_ips.get(0).unwrap(), + port=port, + ); + let request = hyper::Request::builder() + .method(&hyper::Method::PATCH) + .uri(&u) + .body(hyper::Body::from(serde_json::to_string(&model::IpcMessage { + peer: public_key.clone(), + event: e.clone(), + }).unwrap())) + .unwrap(); + let a_client = client.clone(); + join_set.spawn(async move { + log::debug!("Request: {:?}", &request); + let r = a_client.request(request).await; + log::debug!("Response: {:?}", &r); + i+1 + }); } - if send_add | send_del { - cidr_ = - urlencoding::encode(re_extract_ip.captures(&b).unwrap().get(1).unwrap().as_str()) - .into(); - public_key = send_daemon_get_public_key(&iface).await; - } else { - continue; - } - if send_add { - for peer in peers.iter() { - let peer_ip = peer.allowed_ips.get(0).unwrap(); - let u = format!( - "http://{peer}:{port}/wireguard/{iface}/peer/{public_key}/address", - peer = peer_ip, - port = port, - iface = iface, - public_key = public_key - ); - println!("{:?}", u); - let request = hyper::Request::builder() - .method(hyper::Method::POST) - .uri(&u) - .body(hyper::Body::from(cidr_.clone())) - .unwrap(); - let r = hyper::Client::new().request(request).await; - println!("{:?}", r); - } - } else if send_del { - for peer in peers.iter() { - let peer_ip = peer.allowed_ips.get(0).unwrap(); - let u = format!( - "http://{peer}:{port}/wireguard/{iface}/peer/{public_key}/address/{cidr}", - peer = peer_ip, - port = port, - iface = iface, - public_key = public_key, - cidr = &cidr_ - ); - println!("{:?}", u); - let request = hyper::Request::builder() - .method(hyper::Method::DELETE) - .uri(&u) - .body(hyper::Body::empty()) - .unwrap(); - let r = hyper::Client::new().request(request).await; - println!("{:?}", r); + while let Some(joined) = join_set.join_next().await { + let idx = joined.unwrap(); + if idx == 0 { + log::debug!("Failed to send update to all peers, aborting remaining tasks"); + break; } } + join_set.shutdown().await; } } @@ -311,35 +379,20 @@ async fn main() { let config = wg::Wg::from_file(&args.link).await; let alive = warp::path!("alive").and(warp::get()).map(|| BUILD_INFO); - let base = warp::path("wireguard") - .and(warp::path::param::()) - .and(warp::path("peer")) - .and(warp::path::param::()) - .and(warp::path("address")); - let address_add = base - .and(warp::path::end()) - .and(warp::post()) - .and(warp::body::bytes().map(|b: bytes::Bytes| { - cidr::IpInet::from_str( - urlencoding::decode(std::str::from_utf8(b.as_ref()).unwrap().trim()) - .unwrap() - .to_string() - .as_str(), - ) - .unwrap() - })) - .and_then(wg_add_address); - let address_del = base - .and(warp::path::param::()) - .and(warp::path::end()) - .and(warp::delete()) - .and_then(wg_del_address); - let routes = alive.or(address_add.or(address_del)); + let closure_config = config.clone(); + let ipc = warp::path!("ipc") + .and(warp::patch()) + .map(move || closure_config.clone()) + .and(warp::body::json::()) + .and_then(wg_modify_address); + let routes = alive.or(ipc); log::info!("{}", BUILD_INFO); + let (tx, rx1) = tokio::sync::broadcast::channel(16); tokio::select!( _ = warp::serve(routes).run(args.addr) => {}, - _ = send_daemon(args.link, args.addr.port(), config.peers) => {}, + _ = event_generator(args.link, tx) => {}, + _ = event_sender(config, args.addr, rx1) => {}, ); } diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..89497bf --- /dev/null +++ b/src/model.rs @@ -0,0 +1,14 @@ +use std::str::FromStr; +use std::sync::Arc; + +use cidr::IpInet; +use serde::{Deserialize, Serialize}; +use tokio::sync::Mutex; + +use crate::event::Event; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IpcMessage { + pub peer: crate::WgPeer, + pub event: Event, +} diff --git a/src/wg.rs b/src/wg.rs index ad492b5..6658e7c 100644 --- a/src/wg.rs +++ b/src/wg.rs @@ -58,7 +58,7 @@ where return Ok((host, port)); } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[serde_as] #[derive(Serialize, Deserialize)] #[serde(rename_all = "PascalCase")]