diff --git a/Cargo.lock b/Cargo.lock index 3e48080..8c8c1fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -626,6 +626,16 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + [[package]] name = "crossbeam-epoch" version = "0.9.18" @@ -958,6 +968,7 @@ checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -1895,19 +1906,23 @@ version = "0.1.0" dependencies = [ "chrono", "derive-new", + "futures", "log", "logcall", + "once_cell", "oxidetalis_config", "oxidetalis_core", "oxidetalis_entities", "oxidetalis_migrations", "pretty_env_logger", + "rayon", "salvo", "sea-orm", "serde", "serde_json", "thiserror", "tokio", + "uuid", ] [[package]] @@ -2258,6 +2273,26 @@ dependencies = [ "bitflags 2.6.0", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.4.1" diff --git a/crates/oxidetalis/Cargo.toml b/crates/oxidetalis/Cargo.toml index 7a59970..f3a5764 100644 --- a/crates/oxidetalis/Cargo.toml +++ b/crates/oxidetalis/Cargo.toml @@ -23,9 +23,13 @@ thiserror = { workspace = true } chrono = { workspace = true } salvo = { version = "0.68.2", features = ["rustls", "affix", "logging", "oapi", "rate-limiter", "websocket"] } tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } +uuid = { version = "1.9.1", default-features = false, features = ["v4"] } derive-new = "0.6.0" pretty_env_logger = "0.5.0" serde_json = "1.0.117" +once_cell = "1.19.0" +futures = "0.3.30" +rayon = "1.10.0" [lints.rust] unsafe_code = "deny" diff --git a/crates/oxidetalis/src/extensions.rs b/crates/oxidetalis/src/extensions.rs index 9e6b2e0..3f4415e 100644 --- a/crates/oxidetalis/src/extensions.rs +++ b/crates/oxidetalis/src/extensions.rs @@ -18,10 +18,15 @@ use std::sync::Arc; use chrono::Utc; use oxidetalis_config::Config; -use salvo::Depot; +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; +use salvo::{websocket::Message, Depot}; use sea_orm::DatabaseConnection; +use uuid::Uuid; -use crate::{routes::DEPOT_NONCE_CACHE_SIZE, NonceCache}; +use crate::{ + nonce::NonceCache, + websocket::{OnlineUsers, ServerEvent, SocketUserData}, +}; /// Extension trait for the Depot. pub trait DepotExt { @@ -31,15 +36,24 @@ pub trait DepotExt { fn config(&self) -> &Config; /// Retutns the nonce cache fn nonce_cache(&self) -> Arc; - /// Returns the size of the nonce cache - fn nonce_cache_size(&self) -> &usize; } -/// Extension trait for the nonce cache. -pub trait NonceCacheExt { - /// Add a nonce to the cache, returns `true` if the nonce is added, `false` - /// if the nonce is already exist in the cache. - fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool; +/// Extension trait for online websocket users +pub trait OnlineUsersExt { + /// Add new user to the online users + async fn add_user(&self, conn_id: &Uuid, data: SocketUserData); + + /// Remove user from online users + async fn remove_user(&self, conn_id: &Uuid); + + /// Ping all online users + async fn ping_all(&self); + + /// Update user pong at time + async fn update_pong(&self, conn_id: &Uuid); + + /// Disconnect inactive users (who not respond for the ping event) + async fn disconnect_inactive_users(&self); } impl DepotExt for Depot { @@ -58,36 +72,42 @@ impl DepotExt for Depot { .expect("Nonce cache not found"), ) } - - fn nonce_cache_size(&self) -> &usize { - let s: &Arc = self - .get(DEPOT_NONCE_CACHE_SIZE) - .expect("Nonce cache size not found"); - s.as_ref() - } } -impl NonceCacheExt for &NonceCache { - fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool { - let mut cache = self.lock().expect("Nonce cache lock poisoned, aborting..."); - let now = Utc::now().timestamp(); - cache.retain(|_, time| (now - *time) < 30); +impl OnlineUsersExt for OnlineUsers { + async fn add_user(&self, conn_id: &Uuid, data: SocketUserData) { + self.write().await.insert(*conn_id, data); + } - if &cache.len() >= limit { - log::warn!("Nonce cache limit reached, clearing 10% of the cache"); - let num_to_remove = limit / 10; - let keys: Vec<[u8; 16]> = cache.keys().copied().collect(); - for key in keys.iter().take(num_to_remove) { - cache.remove(key); + async fn remove_user(&self, conn_id: &Uuid) { + self.write().await.remove(conn_id); + } + + async fn ping_all(&self) { + let now = Utc::now(); + self.write().await.par_iter_mut().for_each(|(_, u)| { + u.pinged_at = now; + let _ = u.sender.unbounded_send(Ok(Message::from( + &ServerEvent::ping().sign(&u.shared_secret), + ))); + }); + } + + async fn update_pong(&self, conn_id: &Uuid) { + if let Some(user) = self.write().await.get_mut(conn_id) { + user.ponged_at = Utc::now() + } + } + + async fn disconnect_inactive_users(&self) { + self.write().await.retain(|_, u| { + // if we send ping and the client doesn't send pong + if u.pinged_at > u.ponged_at { + log::info!("Disconnected from {}, inactive", u.public_key); + u.sender.close_channel(); + return false; } - } - - // We can use insert directly, but it's will update the value if the key is - // already exist so we need to check if the key is already exist or not - if cache.contains_key(nonce) { - return false; - } - cache.insert(*nonce, now); - true + true + }); } } diff --git a/crates/oxidetalis/src/main.rs b/crates/oxidetalis/src/main.rs index b770e18..53ad66d 100644 --- a/crates/oxidetalis/src/main.rs +++ b/crates/oxidetalis/src/main.rs @@ -17,7 +17,7 @@ #![doc = include_str!("../../../README.md")] #![warn(missing_docs, unsafe_code)] -use std::{collections::HashMap, process::ExitCode, sync::Mutex}; +use std::process::ExitCode; use oxidetalis_config::{CliArgs, Parser}; use oxidetalis_migrations::MigratorTrait; @@ -27,12 +27,11 @@ mod database; mod errors; mod extensions; mod middlewares; +mod nonce; mod routes; mod schemas; mod utils; - -/// Nonce cache type, used to store nonces for a certain amount of time -pub type NonceCache = Mutex>; +mod websocket; async fn try_main() -> errors::Result<()> { pretty_env_logger::init_timed(); @@ -49,6 +48,7 @@ async fn try_main() -> errors::Result<()> { let local_addr = format!("{}:{}", config.server.host, config.server.port); let acceptor = TcpListener::new(&local_addr).bind().await; log::info!("Server listening on http://{local_addr}"); + log::info!("Chat websocket on ws://{local_addr}/ws/chat"); if config.openapi.enable { log::info!( "The openapi schema is available at http://{local_addr}{}", diff --git a/crates/oxidetalis/src/middlewares/signature.rs b/crates/oxidetalis/src/middlewares/signature.rs index 463a1fc..c94d0de 100644 --- a/crates/oxidetalis/src/middlewares/signature.rs +++ b/crates/oxidetalis/src/middlewares/signature.rs @@ -70,7 +70,7 @@ pub async fn signature_check( } }; - if !utils::is_valid_nonce(&signature, &depot.nonce_cache(), depot.nonce_cache_size()) + if !utils::is_valid_nonce(&signature, &depot.nonce_cache()).await || !utils::is_valid_signature( &sender_public_key, &depot.config().server.private_key, diff --git a/crates/oxidetalis/src/nonce.rs b/crates/oxidetalis/src/nonce.rs new file mode 100644 index 0000000..75b247d --- /dev/null +++ b/crates/oxidetalis/src/nonce.rs @@ -0,0 +1,73 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! Nonce cache implementation + +use std::{collections::HashMap, mem}; + +use chrono::Utc; +use oxidetalis_core::types::Size; +use tokio::sync::Mutex as TokioMutex; + +/// Size of each entry in the nonce cache +pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::(); +/// Size of the hashmap itself without the entrys (48 bytes) +pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::>(); + +/// Nonce cache struct, used to store nonces for a short period of time +/// to prevent replay attacks, each nonce has a 30 seconds lifetime. +/// +/// The cache will remove first 10% nonces if the cache limit is reached. +pub struct NonceCache { + /// The nonce cache hashmap, the key is the nonce and the value is the time + cache: TokioMutex>, +} + +impl NonceCache { + /// Creates new [`NonceCache`] instance, with the given cache limit + pub fn new(cache_limit: &Size) -> Self { + Self { + cache: TokioMutex::new(HashMap::with_capacity( + (cache_limit.as_bytes() - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE, + )), + } + } + + /// Add a nonce to the cache, returns `true` if the nonce is added, `false` + /// if the nonce is already exist in the cache. + pub async fn add_nonce(&self, nonce: &[u8; 16]) -> bool { + let mut cache = self.cache.lock().await; + let now = Utc::now().timestamp(); + cache.retain(|_, time| (now - *time) < 30); + + if cache.len() == cache.capacity() { + log::warn!("Nonce cache limit reached, clearing 10% of the cache"); + let num_to_remove = cache.capacity() / 10; + let keys: Vec<[u8; 16]> = cache.keys().copied().collect(); + for key in keys.iter().take(num_to_remove) { + cache.remove(key); + } + } + + // We can use insert directly, but it's will update the value if the key is + // already exist so we need to check if the key is already exist or not + if cache.contains_key(nonce) { + return false; + } + cache.insert(*nonce, now); + true + } +} diff --git a/crates/oxidetalis/src/routes/mod.rs b/crates/oxidetalis/src/routes/mod.rs index 47fddb8..45bdbd5 100644 --- a/crates/oxidetalis/src/routes/mod.rs +++ b/crates/oxidetalis/src/routes/mod.rs @@ -14,9 +14,8 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -use std::collections::HashMap; -use std::sync::{Arc, Mutex}; -use std::{env, mem}; +use std::env; +use std::sync::Arc; use oxidetalis_config::Config; use salvo::http::ResBody; @@ -24,18 +23,12 @@ use salvo::oapi::{Info, License}; use salvo::rate_limiter::{BasicQuota, FixedGuard, MokaStore, RateLimiter, RemoteIpIssuer}; use salvo::{catcher::Catcher, logging::Logger, prelude::*}; +use crate::nonce::NonceCache; use crate::schemas::MessageSchema; -use crate::{middlewares, NonceCache}; +use crate::{middlewares, websocket}; mod user; -/// Size of each entry in the nonce cache -pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::(); -/// Size of the hashmap itself without the entrys (48 bytes) -pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::>(); -/// Name of the nonce cache size in the depot -pub(crate) const DEPOT_NONCE_CACHE_SIZE: &str = "NONCE_CACHE_SIZE"; - pub fn write_json_body(res: &mut Response, json_body: impl serde::Serialize) { res.write_body(serde_json::to_string(&json_body).expect("Json serialization can't be fail")) .ok(); @@ -69,7 +62,7 @@ async fn handle_server_errors(res: &mut Response, ctrl: &mut FlowCtrl) { err.brief.trim_end_matches('.'), err.cause .as_deref() - .map_or_else(|| "".to_owned(), ToString::to_string) + .map_or_else(String::new, ToString::to_string) .trim_end_matches('.') .split(':') .last() @@ -129,23 +122,19 @@ fn route_openapi(config: &Config, router: Router) -> Router { } pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service { - let nonce_cache_size = config.server.nonce_cache_size.as_bytes(); - let nonce_cache: NonceCache = Mutex::new(HashMap::with_capacity( - (nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE, - )); + let nonce_cache: NonceCache = NonceCache::new(&config.server.nonce_cache_size); log::info!( - "Nonce cache created with a capacity of {} ({})", - (nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE, + "Nonce cache created with a capacity of {}", config.server.nonce_cache_size ); let router = Router::new() .push(Router::with_path("user").push(user::route())) + .push(Router::with_path("ws").push(websocket::route())) .hoop(middlewares::add_server_headers) .hoop(Logger::new()) .hoop( affix::inject(Arc::new(conn)) - .insert(DEPOT_NONCE_CACHE_SIZE, Arc::new(nonce_cache_size)) .inject(Arc::new(config.clone())) .inject(Arc::new(nonce_cache)), ); diff --git a/crates/oxidetalis/src/utils.rs b/crates/oxidetalis/src/utils.rs index 76f10c1..861926b 100644 --- a/crates/oxidetalis/src/utils.rs +++ b/crates/oxidetalis/src/utils.rs @@ -27,7 +27,7 @@ use oxidetalis_core::{ }; use salvo::Request; -use crate::{extensions::NonceCacheExt, NonceCache}; +use crate::nonce::NonceCache; /// Returns the postgres database url #[logcall] @@ -43,16 +43,12 @@ pub(crate) fn postgres_url(db_config: &Postgres) -> String { } /// Returns true if the given nonce a valid nonce. -pub(crate) fn is_valid_nonce( - signature: &Signature, - nonce_cache: &NonceCache, - nonce_cache_limit: &usize, -) -> bool { +pub(crate) async fn is_valid_nonce(signature: &Signature, nonce_cache: &NonceCache) -> bool { let new_timestamp = Utc::now() .timestamp() .checked_sub(u64::from_be_bytes(*signature.timestamp()) as i64) .is_some_and(|n| n <= 20); - let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce(), nonce_cache_limit); + let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce()).await; new_timestamp && unused_nonce } diff --git a/crates/oxidetalis/src/websocket/errors.rs b/crates/oxidetalis/src/websocket/errors.rs new file mode 100644 index 0000000..b05e550 --- /dev/null +++ b/crates/oxidetalis/src/websocket/errors.rs @@ -0,0 +1,57 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! Websocket errors + +/// Result type of websocket +pub type WsResult = Result; + +/// Websocket errors, returned in the websocket communication +#[derive(Debug)] +pub enum WsError { + /// The signature is invalid + InvalidSignature, + /// Message type must be text + NotTextMessage, + /// Invalid json data + InvalidJsonData, + /// Unknown client event + UnknownClientEvent, +} + +impl WsError { + /// Returns error name + pub const fn name(&self) -> &'static str { + match self { + WsError::InvalidSignature => "InvalidSignature", + WsError::NotTextMessage => "NotTextMessage", + WsError::InvalidJsonData => "InvalidJsonData", + WsError::UnknownClientEvent => "UnknownClientEvent", + } + } + + /// Returns the error reason + pub const fn reason(&self) -> &'static str { + match self { + WsError::InvalidSignature => "Invalid event signature", + WsError::NotTextMessage => "The websocket message must be text message", + WsError::InvalidJsonData => "Received invalid json data, the text must be valid json", + WsError::UnknownClientEvent => { + "Unknown client event, the event is not recognized by the server" + } + } + } +} diff --git a/crates/oxidetalis/src/websocket/events/client.rs b/crates/oxidetalis/src/websocket/events/client.rs new file mode 100644 index 0000000..a4dda31 --- /dev/null +++ b/crates/oxidetalis/src/websocket/events/client.rs @@ -0,0 +1,60 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! Events that the client send it + +use oxidetalis_core::types::Signature; +use serde::{Deserialize, Serialize}; + +use crate::{nonce::NonceCache, utils}; + +/// Client websocket event +#[derive(Deserialize, Clone, Debug)] +pub struct ClientEvent { + pub event: ClientEventType, + signature: Signature, +} + +/// Client websocket event type +#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)] +#[serde(rename_all = "PascalCase", tag = "event", content = "data")] +pub enum ClientEventType { + /// Ping event + Ping { timestamp: u64 }, + /// Pong event + Pong { timestamp: u64 }, +} + +impl ClientEventType { + /// Returns event data as json bytes + pub fn data(&self) -> Vec { + serde_json::to_value(self).expect("can't fail")["data"] + .to_string() + .into_bytes() + } +} + +impl ClientEvent { + /// Verify the signature of the event + pub async fn verify_signature( + &self, + shared_secret: &[u8; 32], + nonce_cache: &NonceCache, + ) -> bool { + utils::is_valid_nonce(&self.signature, nonce_cache).await + && self.signature.verify(&self.event.data(), shared_secret) + } +} diff --git a/crates/oxidetalis/src/websocket/events/mod.rs b/crates/oxidetalis/src/websocket/events/mod.rs new file mode 100644 index 0000000..6690d70 --- /dev/null +++ b/crates/oxidetalis/src/websocket/events/mod.rs @@ -0,0 +1,23 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! Server and client websocket events + +mod client; +mod server; + +pub use client::*; +pub use server::*; diff --git a/crates/oxidetalis/src/websocket/events/server.rs b/crates/oxidetalis/src/websocket/events/server.rs new file mode 100644 index 0000000..225b129 --- /dev/null +++ b/crates/oxidetalis/src/websocket/events/server.rs @@ -0,0 +1,117 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//! Events that the server send it + +use std::marker::PhantomData; + +use chrono::Utc; +use oxidetalis_core::{cipher::K256Secret, types::Signature}; +use salvo::websocket::Message; +use serde::Serialize; + +use crate::websocket::errors::WsError; + +/// Signed marker, used to indicate that the event is signed +pub struct Signed; +/// Unsigned marker, used to indicate that the event is unsigned +pub struct Unsigned; + +/// Server websocket event +#[derive(Serialize, Clone, Debug)] +pub struct ServerEvent { + #[serde(flatten)] + event: ServerEventType, + signature: Signature, + #[serde(skip)] + phantom: PhantomData, +} + +/// server websocket event type +#[derive(Serialize, Clone, Eq, PartialEq, Debug)] +#[serde(rename_all = "PascalCase")] +pub enum ServerEventType { + /// Ping event + Ping { timestamp: u64 }, + /// Pong event + Pong { timestamp: u64 }, + /// Error event + Error { + name: &'static str, + reason: &'static str, + }, +} + +impl ServerEventType { + /// Returns event data as json bytes + pub fn data(&self) -> Vec { + serde_json::to_value(self).expect("can't fail")["data"] + .to_string() + .into_bytes() + } +} + +impl ServerEvent { + /// Creates new [`ServerEvent`] + pub fn new(event: ServerEventType) -> Self { + Self { + event, + signature: Signature::from([0u8; 56]), + phantom: PhantomData, + } + } + + /// Creates ping event + pub fn ping() -> Self { + Self::new(ServerEventType::Ping { + timestamp: Utc::now().timestamp() as u64, + }) + } + + /// Creates pong event + pub fn pong() -> Self { + Self::new(ServerEventType::Pong { + timestamp: Utc::now().timestamp() as u64, + }) + } + + /// Sign the event + pub fn sign(self, shared_secret: &[u8; 32]) -> ServerEvent { + ServerEvent:: { + signature: K256Secret::sign_with_shared_secret( + &serde_json::to_vec(&self.event.data()).expect("Can't fail"), + shared_secret, + ), + event: self.event, + phantom: PhantomData, + } + } +} + +impl From<&ServerEvent> for Message { + fn from(value: &ServerEvent) -> Self { + Message::text(serde_json::to_string(value).expect("This can't fail")) + } +} + +impl From for ServerEvent { + fn from(err: WsError) -> Self { + ServerEvent::new(ServerEventType::Error { + name: err.name(), + reason: err.reason(), + }) + } +} diff --git a/crates/oxidetalis/src/websocket/mod.rs b/crates/oxidetalis/src/websocket/mod.rs new file mode 100644 index 0000000..c91479e --- /dev/null +++ b/crates/oxidetalis/src/websocket/mod.rs @@ -0,0 +1,219 @@ +// OxideTalis Messaging Protocol homeserver implementation +// Copyright (C) 2024 OxideTalis Developers +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use chrono::Utc; +use errors::{WsError, WsResult}; +use futures::{channel::mpsc, FutureExt, StreamExt, TryStreamExt}; +use once_cell::sync::Lazy; +use oxidetalis_core::{cipher::K256Secret, types::PublicKey}; +use salvo::{ + handler, + http::StatusError, + websocket::{Message, WebSocket, WebSocketUpgrade}, + Depot, + Request, + Response, + Router, +}; +use tokio::{sync::RwLock, task::spawn as tokio_spawn, time::sleep as tokio_sleep}; + +mod errors; +mod events; + +pub use events::*; +use uuid::Uuid; + +use crate::{ + extensions::{DepotExt, OnlineUsersExt}, + middlewares, + nonce::NonceCache, + utils, +}; + +/// Online users type +pub type OnlineUsers = RwLock>; + +/// List of online users, users that are connected to the server +// FIXME: Use `std::sync::LazyLock` after it becomes stable in `1.80.0` +static ONLINE_USERS: Lazy = Lazy::new(OnlineUsers::default); + +/// A user connected to the server +pub struct SocketUserData { + /// Sender to send messages to the user + pub sender: mpsc::UnboundedSender>, + /// User public key + pub public_key: PublicKey, + /// Time that the user pinged at + pub pinged_at: chrono::DateTime, + /// Time that the user ponged at + pub ponged_at: chrono::DateTime, + /// User shared secret + pub shared_secret: [u8; 32], +} + +impl SocketUserData { + /// Creates new [`SocketUserData`] + pub fn new( + public_key: PublicKey, + shared_secret: [u8; 32], + sender: mpsc::UnboundedSender>, + ) -> Self { + let now = Utc::now(); + Self { + sender, + public_key, + shared_secret, + pinged_at: now, + ponged_at: now, + } + } +} + +/// WebSocket handler, that handles the user connection +#[handler] +pub async fn user_connected( + req: &mut Request, + res: &mut Response, + depot: &Depot, +) -> Result<(), StatusError> { + let nonce_cache = depot.nonce_cache(); + let public_key = + utils::extract_public_key(req).expect("The public key was checked in the middleware"); + // FIXME: The config should hold `K256Secret` not `PrivateKey` + let shared_secret = + K256Secret::from_privkey(&depot.config().server.private_key).shared_secret(&public_key); + + WebSocketUpgrade::new() + .upgrade(req, res, move |ws| { + handle_socket(ws, nonce_cache, public_key, shared_secret) + }) + .await +} + +/// Handle the websocket connection +async fn handle_socket( + ws: WebSocket, + nonce_cache: Arc, + user_public_key: PublicKey, + user_shared_secret: [u8; 32], +) { + let (user_ws_sender, mut user_ws_receiver) = ws.split(); + + let (sender, receiver) = mpsc::unbounded(); + let receiver = receiver.into_stream(); + let fut = receiver.forward(user_ws_sender).map(|result| { + if let Err(err) = result { + log::error!("websocket send error: {err}"); + } + }); + tokio_spawn(fut); + let conn_id = Uuid::new_v4(); + let user = SocketUserData::new(user_public_key, user_shared_secret, sender.clone()); + ONLINE_USERS.add_user(&conn_id, user).await; + log::info!("New user connected: ConnId(={conn_id}) PublicKey(={user_public_key})"); + + let fut = async move { + while let Some(Ok(msg)) = user_ws_receiver.next().await { + match handle_ws_msg(msg, &nonce_cache, &user_shared_secret).await { + Ok(event) => { + if let Some(server_event) = handle_events(event, &conn_id).await { + if let Err(err) = sender.unbounded_send(Ok(Message::from( + &server_event.sign(&user_shared_secret), + ))) { + log::error!("Websocket Error: {err}"); + break; + } + }; + } + Err(err) => { + if let Err(err) = sender.unbounded_send(Ok(Message::from( + &ServerEvent::from(err).sign(&user_shared_secret), + ))) { + log::error!("Websocket Error: {err}"); + break; + }; + } + }; + } + user_disconnected(&conn_id, &user_public_key).await; + }; + tokio_spawn(fut); +} + +/// Handle websocket msg +async fn handle_ws_msg( + msg: Message, + nonce_cache: &NonceCache, + shared_secret: &[u8; 32], +) -> WsResult { + let Ok(text) = msg.to_str() else { + return Err(WsError::NotTextMessage); + }; + let event = serde_json::from_str::(text).map_err(|err| { + if err.is_data() { + WsError::UnknownClientEvent + } else { + WsError::InvalidJsonData + } + })?; + if !event.verify_signature(shared_secret, nonce_cache).await { + return Err(WsError::InvalidSignature); + } + Ok(event) +} + +/// Handle user events, and return the server event if needed +async fn handle_events(event: ClientEvent, conn_id: &Uuid) -> Option> { + match &event.event { + ClientEventType::Ping { .. } => Some(ServerEvent::pong()), + ClientEventType::Pong { .. } => { + ONLINE_USERS.update_pong(conn_id).await; + None + } + } +} + +/// Handle user disconnected +async fn user_disconnected(conn_id: &Uuid, public_key: &PublicKey) { + ONLINE_USERS.remove_user(conn_id).await; + log::debug!("User disconnect: ConnId(={conn_id}) PublicKey(={public_key})"); +} + +pub fn route() -> Router { + let users_pinger = async { + /// Seconds to wait for pongs, before disconnecting the user + const WAIT_FOR_PONGS_SECS: u32 = 10; + /// Seconds to sleep between pings (10 minutes) + const SLEEP_SECS: u32 = 60 * 10; + loop { + log::debug!("Start pinging online users"); + ONLINE_USERS.ping_all().await; + tokio_sleep(Duration::from_secs(u64::from(WAIT_FOR_PONGS_SECS))).await; + ONLINE_USERS.disconnect_inactive_users().await; + log::debug!("Done pinging online users and disconnected inactive ones"); + tokio_sleep(Duration::from_secs(u64::from(SLEEP_SECS))).await; + } + }; + + tokio_spawn(users_pinger); + + Router::new() + .push(Router::with_path("chat").get(user_connected)) + .hoop(middlewares::signature_check) + .hoop(middlewares::public_key_check) +}