diff --git a/crates/oxidetalis/src/extensions.rs b/crates/oxidetalis/src/extensions.rs index 7cc2c27..bc06067 100644 --- a/crates/oxidetalis/src/extensions.rs +++ b/crates/oxidetalis/src/extensions.rs @@ -142,17 +142,14 @@ impl OnlineUsersExt for OnlineUsers { } async fn disconnect_inactive_users(&self) { - let now = Utc::now().timestamp(); - let is_inactive = - |u: &SocketUserData| u.pinged_at > u.ponged_at && now - u.pinged_at.timestamp() >= 5; - self.read() - .await - .iter() - .filter(|(_, u)| is_inactive(u)) - .for_each(|(_, u)| { + self.write().await.retain(|_, u| { + // if we send ping and the client doesn't send pong + if u.pinged_at > u.ponged_at { log::info!("Disconnected from {}, inactive", u.public_key); - u.sender.close_channel() - }); - self.write().await.retain(|_, u| !is_inactive(u)); + u.sender.close_channel(); + return false; + } + true + }); } } diff --git a/crates/oxidetalis/src/main.rs b/crates/oxidetalis/src/main.rs index b770e18..a440c74 100644 --- a/crates/oxidetalis/src/main.rs +++ b/crates/oxidetalis/src/main.rs @@ -30,6 +30,7 @@ mod middlewares; mod routes; mod schemas; mod utils; +mod websocket; /// Nonce cache type, used to store nonces for a certain amount of time pub type NonceCache = Mutex>; diff --git a/crates/oxidetalis/src/routes/mod.rs b/crates/oxidetalis/src/routes/mod.rs index 47fddb8..06b9a63 100644 --- a/crates/oxidetalis/src/routes/mod.rs +++ b/crates/oxidetalis/src/routes/mod.rs @@ -25,7 +25,7 @@ use salvo::rate_limiter::{BasicQuota, FixedGuard, MokaStore, RateLimiter, Remote use salvo::{catcher::Catcher, logging::Logger, prelude::*}; use crate::schemas::MessageSchema; -use crate::{middlewares, NonceCache}; +use crate::{middlewares, websocket, NonceCache}; mod user; @@ -141,6 +141,7 @@ pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service { let router = Router::new() .push(Router::with_path("user").push(user::route())) + .push(Router::with_path("ws").push(websocket::route())) .hoop(middlewares::add_server_headers) .hoop(Logger::new()) .hoop( diff --git a/crates/oxidetalis/src/websocket/errors.rs b/crates/oxidetalis/src/websocket/errors.rs new file mode 100644 index 0000000..b05e550 --- /dev/null +++ b/crates/oxidetalis/src/websocket/errors.rs @@ -0,0 +1,57 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! Websocket errors + +/// Result type of websocket +pub type WsResult = Result; + +/// Websocket errors, returned in the websocket communication +#[derive(Debug)] +pub enum WsError { + /// The signature is invalid + InvalidSignature, + /// Message type must be text + NotTextMessage, + /// Invalid json data + InvalidJsonData, + /// Unknown client event + UnknownClientEvent, +} + +impl WsError { + /// Returns error name + pub const fn name(&self) -> &'static str { + match self { + WsError::InvalidSignature => "InvalidSignature", + WsError::NotTextMessage => "NotTextMessage", + WsError::InvalidJsonData => "InvalidJsonData", + WsError::UnknownClientEvent => "UnknownClientEvent", + } + } + + /// Returns the error reason + pub const fn reason(&self) -> &'static str { + match self { + WsError::InvalidSignature => "Invalid event signature", + WsError::NotTextMessage => "The websocket message must be text message", + WsError::InvalidJsonData => "Received invalid json data, the text must be valid json", + WsError::UnknownClientEvent => { + "Unknown client event, the event is not recognized by the server" + } + } + } +} diff --git a/crates/oxidetalis/src/websocket/events/client.rs b/crates/oxidetalis/src/websocket/events/client.rs new file mode 100644 index 0000000..25e688b --- /dev/null +++ b/crates/oxidetalis/src/websocket/events/client.rs @@ -0,0 +1,61 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! Events that the client send it + +use oxidetalis_core::types::Signature; +use serde::{Deserialize, Serialize}; + +use crate::{utils, NonceCache}; + +/// Client websocket event +#[derive(Deserialize, Clone, Debug)] +pub struct ClientEvent { + pub event: ClientEventType, + signature: Signature, +} + +/// Client websocket event type +#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)] +#[serde(rename_all = "PascalCase", tag = "event", content = "data")] +pub enum ClientEventType { + /// Ping event + Ping { timestamp: u64 }, + /// Pong event + Pong { timestamp: u64 }, +} + +impl ClientEventType { + /// Returns event data as json bytes + pub fn data(&self) -> Vec { + serde_json::to_value(self).expect("can't fail")["data"] + .to_string() + .into_bytes() + } +} + +impl ClientEvent { + /// Verify the signature of the event + pub fn verify_signature( + &self, + shared_secret: &[u8; 32], + nonce_cache: &NonceCache, + nonce_limit: &usize, + ) -> bool { + utils::is_valid_nonce(&self.signature, nonce_cache, nonce_limit) + && self.signature.verify(&self.event.data(), shared_secret) + } +} diff --git a/crates/oxidetalis/src/websocket/events/mod.rs b/crates/oxidetalis/src/websocket/events/mod.rs new file mode 100644 index 0000000..6690d70 --- /dev/null +++ b/crates/oxidetalis/src/websocket/events/mod.rs @@ -0,0 +1,23 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! Server and client websocket events + +mod client; +mod server; + +pub use client::*; +pub use server::*; diff --git a/crates/oxidetalis/src/websocket/events/server.rs b/crates/oxidetalis/src/websocket/events/server.rs new file mode 100644 index 0000000..225b129 --- /dev/null +++ b/crates/oxidetalis/src/websocket/events/server.rs @@ -0,0 +1,117 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! Events that the server send it + +use std::marker::PhantomData; + +use chrono::Utc; +use oxidetalis_core::{cipher::K256Secret, types::Signature}; +use salvo::websocket::Message; +use serde::Serialize; + +use crate::websocket::errors::WsError; + +/// Signed marker, used to indicate that the event is signed +pub struct Signed; +/// Unsigned marker, used to indicate that the event is unsigned +pub struct Unsigned; + +/// Server websocket event +#[derive(Serialize, Clone, Debug)] +pub struct ServerEvent { + #[serde(flatten)] + event: ServerEventType, + signature: Signature, + #[serde(skip)] + phantom: PhantomData, +} + +/// server websocket event type +#[derive(Serialize, Clone, Eq, PartialEq, Debug)] +#[serde(rename_all = "PascalCase")] +pub enum ServerEventType { + /// Ping event + Ping { timestamp: u64 }, + /// Pong event + Pong { timestamp: u64 }, + /// Error event + Error { + name: &'static str, + reason: &'static str, + }, +} + +impl ServerEventType { + /// Returns event data as json bytes + pub fn data(&self) -> Vec { + serde_json::to_value(self).expect("can't fail")["data"] + .to_string() + .into_bytes() + } +} + +impl ServerEvent { + /// Creates new [`ServerEvent`] + pub fn new(event: ServerEventType) -> Self { + Self { + event, + signature: Signature::from([0u8; 56]), + phantom: PhantomData, + } + } + + /// Creates ping event + pub fn ping() -> Self { + Self::new(ServerEventType::Ping { + timestamp: Utc::now().timestamp() as u64, + }) + } + + /// Creates pong event + pub fn pong() -> Self { + Self::new(ServerEventType::Pong { + timestamp: Utc::now().timestamp() as u64, + }) + } + + /// Sign the event + pub fn sign(self, shared_secret: &[u8; 32]) -> ServerEvent { + ServerEvent:: { + signature: K256Secret::sign_with_shared_secret( + &serde_json::to_vec(&self.event.data()).expect("Can't fail"), + shared_secret, + ), + event: self.event, + phantom: PhantomData, + } + } +} + +impl From<&ServerEvent> for Message { + fn from(value: &ServerEvent) -> Self { + Message::text(serde_json::to_string(value).expect("This can't fail")) + } +} + +impl From for ServerEvent { + fn from(err: WsError) -> Self { + ServerEvent::new(ServerEventType::Error { + name: err.name(), + reason: err.reason(), + }) + } +} diff --git a/crates/oxidetalis/src/websocket/mod.rs b/crates/oxidetalis/src/websocket/mod.rs new file mode 100644 index 0000000..4f8aa6a --- /dev/null +++ b/crates/oxidetalis/src/websocket/mod.rs @@ -0,0 +1,222 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use chrono::Utc; +use errors::{WsError, WsResult}; +use futures::{channel::mpsc, FutureExt, StreamExt, TryStreamExt}; +use once_cell::sync::Lazy; +use oxidetalis_core::{cipher::K256Secret, types::PublicKey}; +use salvo::{ + handler, + http::StatusError, + websocket::{Message, WebSocket, WebSocketUpgrade}, + Depot, + Request, + Response, + Router, +}; +use tokio::{sync::RwLock, task::spawn as tokio_spawn, time::sleep as tokio_sleep}; + +mod errors; +mod events; + +pub use events::*; +use uuid::Uuid; + +use crate::{ + extensions::{DepotExt, OnlineUsersExt}, + middlewares, + utils, + NonceCache, +}; + +/// Online users type +pub type OnlineUsers = RwLock>; + +/// List of online users, users that are connected to the server +// FIXME: Use `std::sync::LazyLock` after it becomes stable in `1.80.0` +static ONLINE_USERS: Lazy = Lazy::new(OnlineUsers::default); + +/// A user connected to the server +pub struct SocketUserData { + /// Sender to send messages to the user + pub sender: mpsc::UnboundedSender>, + /// User public key + pub public_key: PublicKey, + /// Time that the user pinged at + pub pinged_at: chrono::DateTime, + /// Time that the user ponged at + pub ponged_at: chrono::DateTime, + /// User shared secret + pub shared_secret: [u8; 32], +} + +impl SocketUserData { + /// Creates new [`SocketUserData`] + pub fn new( + public_key: PublicKey, + shared_secret: [u8; 32], + sender: mpsc::UnboundedSender>, + ) -> Self { + let now = Utc::now(); + Self { + sender, + public_key, + shared_secret, + pinged_at: now, + ponged_at: now, + } + } +} + +/// WebSocket handler, that handles the user connection +#[handler] +pub async fn user_connected( + req: &mut Request, + res: &mut Response, + depot: &Depot, +) -> Result<(), StatusError> { + let nonce_cache = depot.nonce_cache(); + let nonce_limit = *depot.nonce_cache_size(); + let public_key = + utils::extract_public_key(req).expect("The public key was checked in the middleware"); + // FIXME: The config should hold `K256Secret` not `PrivateKey` + let shared_secret = + K256Secret::from_privkey(&depot.config().server.private_key).shared_secret(&public_key); + + WebSocketUpgrade::new() + .upgrade(req, res, move |ws| { + handle_socket(ws, nonce_cache, nonce_limit, public_key, shared_secret) + }) + .await +} + +/// Handle the websocket connection +async fn handle_socket( + ws: WebSocket, + nonce_cache: Arc, + nonce_limit: usize, + user_public_key: PublicKey, + user_shared_secret: [u8; 32], +) { + let (user_ws_sender, mut user_ws_receiver) = ws.split(); + + let (sender, receiver) = mpsc::unbounded(); + let receiver = receiver.into_stream(); + let fut = receiver.forward(user_ws_sender).map(|result| { + if let Err(err) = result { + log::error!("websocket send error: {err}"); + } + }); + tokio_spawn(fut); + let conn_id = Uuid::new_v4(); + let user = SocketUserData::new(user_public_key, user_shared_secret, sender.clone()); + ONLINE_USERS.add_user(&conn_id, user).await; + log::info!("New user connected: ConnId(={conn_id}) PublicKey(={user_public_key})"); + + let fut = async move { + while let Some(Ok(msg)) = user_ws_receiver.next().await { + match handle_ws_msg(msg, &nonce_cache, &nonce_limit, &user_shared_secret) { + Ok(event) => { + if let Some(server_event) = handle_events(event, &conn_id).await { + if let Err(err) = sender.unbounded_send(Ok(Message::from( + &server_event.sign(&user_shared_secret), + ))) { + log::error!("Websocket Error: {err}"); + break; + } + }; + } + Err(err) => { + if let Err(err) = sender.unbounded_send(Ok(Message::from( + &ServerEvent::from(err).sign(&user_shared_secret), + ))) { + log::error!("Websocket Error: {err}"); + break; + }; + } + }; + } + user_disconnected(&conn_id, &user_public_key).await; + }; + tokio_spawn(fut); +} + +/// Handle websocket msg +fn handle_ws_msg( + msg: Message, + nonce_cache: &NonceCache, + nonce_limit: &usize, + shared_secret: &[u8; 32], +) -> WsResult { + let Ok(text) = msg.to_str() else { + return Err(WsError::NotTextMessage); + }; + let event = serde_json::from_str::(text).map_err(|err| { + if err.is_data() { + WsError::UnknownClientEvent + } else { + WsError::InvalidJsonData + } + })?; + if !event.verify_signature(shared_secret, nonce_cache, nonce_limit) { + return Err(WsError::InvalidSignature); + } + Ok(event) +} + +/// Handle user events, and return the server event if needed +async fn handle_events(event: ClientEvent, conn_id: &Uuid) -> Option> { + match &event.event { + ClientEventType::Ping { .. } => Some(ServerEvent::pong()), + ClientEventType::Pong { .. } => { + ONLINE_USERS.update_pong(conn_id).await; + None + } + } +} + +/// Handle user disconnected +async fn user_disconnected(conn_id: &Uuid, public_key: &PublicKey) { + ONLINE_USERS.remove_user(conn_id).await; + log::debug!("User disconnect: ConnId(={conn_id}) PublicKey(={public_key})"); +} + +pub fn route() -> Router { + let users_pinger = async { + /// Seconds to wait for pongs, before disconnecting the user + const WAIT_FOR_PONGS_SECS: u32 = 10; + /// Seconds to sleep between pings (10 minutes) + const SLEEP_SECS: u32 = 60 * 10; + loop { + log::debug!("Start pinging online users"); + ONLINE_USERS.ping_all().await; + tokio_sleep(Duration::from_secs(u64::from(WAIT_FOR_PONGS_SECS))).await; + ONLINE_USERS.disconnect_inactive_users().await; + log::debug!("Done pinging online users and disconnected inactive ones"); + tokio_sleep(Duration::from_secs(u64::from(SLEEP_SECS))).await; + } + }; + + tokio_spawn(users_pinger); + + Router::new() + .push(Router::with_path("chat").get(user_connected)) + .hoop(middlewares::signature_check) + .hoop(middlewares::public_key_check) +}