From 6f43e44745bb8a3f57ac9c7aa66ef9ad049f88a8 Mon Sep 17 00:00:00 2001
From: Awiteb
Date: Fri, 5 Jul 2024 02:16:10 +0300
Subject: [PATCH 1/4] chore: Add new dependencies for websocket
Signed-off-by: Awiteb
---
Cargo.lock | 35 +++++++++++++++++++++++++++++++++++
crates/oxidetalis/Cargo.toml | 4 ++++
2 files changed, 39 insertions(+)
diff --git a/Cargo.lock b/Cargo.lock
index 3e48080..8c8c1fc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -626,6 +626,16 @@ dependencies = [
"crossbeam-utils",
]
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
@@ -958,6 +968,7 @@ checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
dependencies = [
"futures-channel",
"futures-core",
+ "futures-executor",
"futures-io",
"futures-sink",
"futures-task",
@@ -1895,19 +1906,23 @@ version = "0.1.0"
dependencies = [
"chrono",
"derive-new",
+ "futures",
"log",
"logcall",
+ "once_cell",
"oxidetalis_config",
"oxidetalis_core",
"oxidetalis_entities",
"oxidetalis_migrations",
"pretty_env_logger",
+ "rayon",
"salvo",
"sea-orm",
"serde",
"serde_json",
"thiserror",
"tokio",
+ "uuid",
]
[[package]]
@@ -2258,6 +2273,26 @@ dependencies = [
"bitflags 2.6.0",
]
+[[package]]
+name = "rayon"
+version = "1.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
[[package]]
name = "redox_syscall"
version = "0.4.1"
diff --git a/crates/oxidetalis/Cargo.toml b/crates/oxidetalis/Cargo.toml
index 7a59970..f3a5764 100644
--- a/crates/oxidetalis/Cargo.toml
+++ b/crates/oxidetalis/Cargo.toml
@@ -23,9 +23,13 @@ thiserror = { workspace = true }
chrono = { workspace = true }
salvo = { version = "0.68.2", features = ["rustls", "affix", "logging", "oapi", "rate-limiter", "websocket"] }
tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] }
+uuid = { version = "1.9.1", default-features = false, features = ["v4"] }
derive-new = "0.6.0"
pretty_env_logger = "0.5.0"
serde_json = "1.0.117"
+once_cell = "1.19.0"
+futures = "0.3.30"
+rayon = "1.10.0"
[lints.rust]
unsafe_code = "deny"
--
2.45.2
From cd2a9ea03ef08000401fd0a9807cfa36301d48f9 Mon Sep 17 00:00:00 2001
From: Awiteb
Date: Fri, 5 Jul 2024 02:17:10 +0300
Subject: [PATCH 2/4] feat: New extension trait for websocket online users
Signed-off-by: Awiteb
---
crates/oxidetalis/src/extensions.rs | 69 ++++++++++++++++++++++++++++-
1 file changed, 67 insertions(+), 2 deletions(-)
diff --git a/crates/oxidetalis/src/extensions.rs b/crates/oxidetalis/src/extensions.rs
index 9e6b2e0..7cc2c27 100644
--- a/crates/oxidetalis/src/extensions.rs
+++ b/crates/oxidetalis/src/extensions.rs
@@ -18,10 +18,16 @@ use std::sync::Arc;
use chrono::Utc;
use oxidetalis_config::Config;
-use salvo::Depot;
+use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
+use salvo::{websocket::Message, Depot};
use sea_orm::DatabaseConnection;
+use uuid::Uuid;
-use crate::{routes::DEPOT_NONCE_CACHE_SIZE, NonceCache};
+use crate::{
+ routes::DEPOT_NONCE_CACHE_SIZE,
+ websocket::{OnlineUsers, ServerEvent, SocketUserData},
+ NonceCache,
+};
/// Extension trait for the Depot.
pub trait DepotExt {
@@ -42,6 +48,24 @@ pub trait NonceCacheExt {
fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool;
}
+/// Extension trait for online websocket users
+pub trait OnlineUsersExt {
+ /// Add new user to the online users
+ async fn add_user(&self, conn_id: &Uuid, data: SocketUserData);
+
+ /// Remove user from online users
+ async fn remove_user(&self, conn_id: &Uuid);
+
+ /// Ping all online users
+ async fn ping_all(&self);
+
+ /// Update user pong at time
+ async fn update_pong(&self, conn_id: &Uuid);
+
+ /// Disconnect inactive users (who not respond for the ping event)
+ async fn disconnect_inactive_users(&self);
+}
+
impl DepotExt for Depot {
fn db_conn(&self) -> &DatabaseConnection {
self.obtain::>()
@@ -91,3 +115,44 @@ impl NonceCacheExt for &NonceCache {
true
}
}
+
+impl OnlineUsersExt for OnlineUsers {
+ async fn add_user(&self, conn_id: &Uuid, data: SocketUserData) {
+ self.write().await.insert(*conn_id, data);
+ }
+
+ async fn remove_user(&self, conn_id: &Uuid) {
+ self.write().await.remove(conn_id);
+ }
+
+ async fn ping_all(&self) {
+ let now = Utc::now();
+ self.write().await.par_iter_mut().for_each(|(_, u)| {
+ u.pinged_at = now;
+ let _ = u.sender.unbounded_send(Ok(Message::from(
+ &ServerEvent::ping().sign(&u.shared_secret),
+ )));
+ });
+ }
+
+ async fn update_pong(&self, conn_id: &Uuid) {
+ if let Some(user) = self.write().await.get_mut(conn_id) {
+ user.ponged_at = Utc::now()
+ }
+ }
+
+ async fn disconnect_inactive_users(&self) {
+ let now = Utc::now().timestamp();
+ let is_inactive =
+ |u: &SocketUserData| u.pinged_at > u.ponged_at && now - u.pinged_at.timestamp() >= 5;
+ self.read()
+ .await
+ .iter()
+ .filter(|(_, u)| is_inactive(u))
+ .for_each(|(_, u)| {
+ log::info!("Disconnected from {}, inactive", u.public_key);
+ u.sender.close_channel()
+ });
+ self.write().await.retain(|_, u| !is_inactive(u));
+ }
+}
--
2.45.2
From c0d5efe0c342d87849bb1cb295f9a2f3dfc64a70 Mon Sep 17 00:00:00 2001
From: Awiteb
Date: Fri, 5 Jul 2024 02:19:58 +0300
Subject: [PATCH 3/4] feat: Initialize server websocket
Related-to: https://git.4rs.nl/OxideTalis/oxidetalis/issues/2
Reviewed-by: Amjad Alsharafi
Signed-off-by: Awiteb
---
crates/oxidetalis/src/extensions.rs | 19 +-
crates/oxidetalis/src/main.rs | 1 +
crates/oxidetalis/src/routes/mod.rs | 3 +-
crates/oxidetalis/src/websocket/errors.rs | 57 +++++
.../oxidetalis/src/websocket/events/client.rs | 61 +++++
crates/oxidetalis/src/websocket/events/mod.rs | 23 ++
.../oxidetalis/src/websocket/events/server.rs | 117 +++++++++
crates/oxidetalis/src/websocket/mod.rs | 222 ++++++++++++++++++
8 files changed, 491 insertions(+), 12 deletions(-)
create mode 100644 crates/oxidetalis/src/websocket/errors.rs
create mode 100644 crates/oxidetalis/src/websocket/events/client.rs
create mode 100644 crates/oxidetalis/src/websocket/events/mod.rs
create mode 100644 crates/oxidetalis/src/websocket/events/server.rs
create mode 100644 crates/oxidetalis/src/websocket/mod.rs
diff --git a/crates/oxidetalis/src/extensions.rs b/crates/oxidetalis/src/extensions.rs
index 7cc2c27..bc06067 100644
--- a/crates/oxidetalis/src/extensions.rs
+++ b/crates/oxidetalis/src/extensions.rs
@@ -142,17 +142,14 @@ impl OnlineUsersExt for OnlineUsers {
}
async fn disconnect_inactive_users(&self) {
- let now = Utc::now().timestamp();
- let is_inactive =
- |u: &SocketUserData| u.pinged_at > u.ponged_at && now - u.pinged_at.timestamp() >= 5;
- self.read()
- .await
- .iter()
- .filter(|(_, u)| is_inactive(u))
- .for_each(|(_, u)| {
+ self.write().await.retain(|_, u| {
+ // if we send ping and the client doesn't send pong
+ if u.pinged_at > u.ponged_at {
log::info!("Disconnected from {}, inactive", u.public_key);
- u.sender.close_channel()
- });
- self.write().await.retain(|_, u| !is_inactive(u));
+ u.sender.close_channel();
+ return false;
+ }
+ true
+ });
}
}
diff --git a/crates/oxidetalis/src/main.rs b/crates/oxidetalis/src/main.rs
index b770e18..a440c74 100644
--- a/crates/oxidetalis/src/main.rs
+++ b/crates/oxidetalis/src/main.rs
@@ -30,6 +30,7 @@ mod middlewares;
mod routes;
mod schemas;
mod utils;
+mod websocket;
/// Nonce cache type, used to store nonces for a certain amount of time
pub type NonceCache = Mutex>;
diff --git a/crates/oxidetalis/src/routes/mod.rs b/crates/oxidetalis/src/routes/mod.rs
index 47fddb8..06b9a63 100644
--- a/crates/oxidetalis/src/routes/mod.rs
+++ b/crates/oxidetalis/src/routes/mod.rs
@@ -25,7 +25,7 @@ use salvo::rate_limiter::{BasicQuota, FixedGuard, MokaStore, RateLimiter, Remote
use salvo::{catcher::Catcher, logging::Logger, prelude::*};
use crate::schemas::MessageSchema;
-use crate::{middlewares, NonceCache};
+use crate::{middlewares, websocket, NonceCache};
mod user;
@@ -141,6 +141,7 @@ pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service {
let router = Router::new()
.push(Router::with_path("user").push(user::route()))
+ .push(Router::with_path("ws").push(websocket::route()))
.hoop(middlewares::add_server_headers)
.hoop(Logger::new())
.hoop(
diff --git a/crates/oxidetalis/src/websocket/errors.rs b/crates/oxidetalis/src/websocket/errors.rs
new file mode 100644
index 0000000..b05e550
--- /dev/null
+++ b/crates/oxidetalis/src/websocket/errors.rs
@@ -0,0 +1,57 @@
+// OxideTalis Messaging Protocol homeserver implementation
+// Copyright (C) 2024 OxideTalis Developers
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//! Websocket errors
+
+/// Result type of websocket
+pub type WsResult = Result;
+
+/// Websocket errors, returned in the websocket communication
+#[derive(Debug)]
+pub enum WsError {
+ /// The signature is invalid
+ InvalidSignature,
+ /// Message type must be text
+ NotTextMessage,
+ /// Invalid json data
+ InvalidJsonData,
+ /// Unknown client event
+ UnknownClientEvent,
+}
+
+impl WsError {
+ /// Returns error name
+ pub const fn name(&self) -> &'static str {
+ match self {
+ WsError::InvalidSignature => "InvalidSignature",
+ WsError::NotTextMessage => "NotTextMessage",
+ WsError::InvalidJsonData => "InvalidJsonData",
+ WsError::UnknownClientEvent => "UnknownClientEvent",
+ }
+ }
+
+ /// Returns the error reason
+ pub const fn reason(&self) -> &'static str {
+ match self {
+ WsError::InvalidSignature => "Invalid event signature",
+ WsError::NotTextMessage => "The websocket message must be text message",
+ WsError::InvalidJsonData => "Received invalid json data, the text must be valid json",
+ WsError::UnknownClientEvent => {
+ "Unknown client event, the event is not recognized by the server"
+ }
+ }
+ }
+}
diff --git a/crates/oxidetalis/src/websocket/events/client.rs b/crates/oxidetalis/src/websocket/events/client.rs
new file mode 100644
index 0000000..25e688b
--- /dev/null
+++ b/crates/oxidetalis/src/websocket/events/client.rs
@@ -0,0 +1,61 @@
+// OxideTalis Messaging Protocol homeserver implementation
+// Copyright (C) 2024 OxideTalis Developers
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//! Events that the client send it
+
+use oxidetalis_core::types::Signature;
+use serde::{Deserialize, Serialize};
+
+use crate::{utils, NonceCache};
+
+/// Client websocket event
+#[derive(Deserialize, Clone, Debug)]
+pub struct ClientEvent {
+ pub event: ClientEventType,
+ signature: Signature,
+}
+
+/// Client websocket event type
+#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
+#[serde(rename_all = "PascalCase", tag = "event", content = "data")]
+pub enum ClientEventType {
+ /// Ping event
+ Ping { timestamp: u64 },
+ /// Pong event
+ Pong { timestamp: u64 },
+}
+
+impl ClientEventType {
+ /// Returns event data as json bytes
+ pub fn data(&self) -> Vec {
+ serde_json::to_value(self).expect("can't fail")["data"]
+ .to_string()
+ .into_bytes()
+ }
+}
+
+impl ClientEvent {
+ /// Verify the signature of the event
+ pub fn verify_signature(
+ &self,
+ shared_secret: &[u8; 32],
+ nonce_cache: &NonceCache,
+ nonce_limit: &usize,
+ ) -> bool {
+ utils::is_valid_nonce(&self.signature, nonce_cache, nonce_limit)
+ && self.signature.verify(&self.event.data(), shared_secret)
+ }
+}
diff --git a/crates/oxidetalis/src/websocket/events/mod.rs b/crates/oxidetalis/src/websocket/events/mod.rs
new file mode 100644
index 0000000..6690d70
--- /dev/null
+++ b/crates/oxidetalis/src/websocket/events/mod.rs
@@ -0,0 +1,23 @@
+// OxideTalis Messaging Protocol homeserver implementation
+// Copyright (C) 2024 OxideTalis Developers
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//! Server and client websocket events
+
+mod client;
+mod server;
+
+pub use client::*;
+pub use server::*;
diff --git a/crates/oxidetalis/src/websocket/events/server.rs b/crates/oxidetalis/src/websocket/events/server.rs
new file mode 100644
index 0000000..225b129
--- /dev/null
+++ b/crates/oxidetalis/src/websocket/events/server.rs
@@ -0,0 +1,117 @@
+// OxideTalis Messaging Protocol homeserver implementation
+// Copyright (C) 2024 OxideTalis Developers
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//! Events that the server send it
+
+use std::marker::PhantomData;
+
+use chrono::Utc;
+use oxidetalis_core::{cipher::K256Secret, types::Signature};
+use salvo::websocket::Message;
+use serde::Serialize;
+
+use crate::websocket::errors::WsError;
+
+/// Signed marker, used to indicate that the event is signed
+pub struct Signed;
+/// Unsigned marker, used to indicate that the event is unsigned
+pub struct Unsigned;
+
+/// Server websocket event
+#[derive(Serialize, Clone, Debug)]
+pub struct ServerEvent {
+ #[serde(flatten)]
+ event: ServerEventType,
+ signature: Signature,
+ #[serde(skip)]
+ phantom: PhantomData,
+}
+
+/// server websocket event type
+#[derive(Serialize, Clone, Eq, PartialEq, Debug)]
+#[serde(rename_all = "PascalCase")]
+pub enum ServerEventType {
+ /// Ping event
+ Ping { timestamp: u64 },
+ /// Pong event
+ Pong { timestamp: u64 },
+ /// Error event
+ Error {
+ name: &'static str,
+ reason: &'static str,
+ },
+}
+
+impl ServerEventType {
+ /// Returns event data as json bytes
+ pub fn data(&self) -> Vec {
+ serde_json::to_value(self).expect("can't fail")["data"]
+ .to_string()
+ .into_bytes()
+ }
+}
+
+impl ServerEvent {
+ /// Creates new [`ServerEvent`]
+ pub fn new(event: ServerEventType) -> Self {
+ Self {
+ event,
+ signature: Signature::from([0u8; 56]),
+ phantom: PhantomData,
+ }
+ }
+
+ /// Creates ping event
+ pub fn ping() -> Self {
+ Self::new(ServerEventType::Ping {
+ timestamp: Utc::now().timestamp() as u64,
+ })
+ }
+
+ /// Creates pong event
+ pub fn pong() -> Self {
+ Self::new(ServerEventType::Pong {
+ timestamp: Utc::now().timestamp() as u64,
+ })
+ }
+
+ /// Sign the event
+ pub fn sign(self, shared_secret: &[u8; 32]) -> ServerEvent {
+ ServerEvent:: {
+ signature: K256Secret::sign_with_shared_secret(
+ &serde_json::to_vec(&self.event.data()).expect("Can't fail"),
+ shared_secret,
+ ),
+ event: self.event,
+ phantom: PhantomData,
+ }
+ }
+}
+
+impl From<&ServerEvent> for Message {
+ fn from(value: &ServerEvent) -> Self {
+ Message::text(serde_json::to_string(value).expect("This can't fail"))
+ }
+}
+
+impl From for ServerEvent {
+ fn from(err: WsError) -> Self {
+ ServerEvent::new(ServerEventType::Error {
+ name: err.name(),
+ reason: err.reason(),
+ })
+ }
+}
diff --git a/crates/oxidetalis/src/websocket/mod.rs b/crates/oxidetalis/src/websocket/mod.rs
new file mode 100644
index 0000000..4f8aa6a
--- /dev/null
+++ b/crates/oxidetalis/src/websocket/mod.rs
@@ -0,0 +1,222 @@
+// OxideTalis Messaging Protocol homeserver implementation
+// Copyright (C) 2024 OxideTalis Developers
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+use std::{collections::HashMap, sync::Arc, time::Duration};
+
+use chrono::Utc;
+use errors::{WsError, WsResult};
+use futures::{channel::mpsc, FutureExt, StreamExt, TryStreamExt};
+use once_cell::sync::Lazy;
+use oxidetalis_core::{cipher::K256Secret, types::PublicKey};
+use salvo::{
+ handler,
+ http::StatusError,
+ websocket::{Message, WebSocket, WebSocketUpgrade},
+ Depot,
+ Request,
+ Response,
+ Router,
+};
+use tokio::{sync::RwLock, task::spawn as tokio_spawn, time::sleep as tokio_sleep};
+
+mod errors;
+mod events;
+
+pub use events::*;
+use uuid::Uuid;
+
+use crate::{
+ extensions::{DepotExt, OnlineUsersExt},
+ middlewares,
+ utils,
+ NonceCache,
+};
+
+/// Online users type
+pub type OnlineUsers = RwLock>;
+
+/// List of online users, users that are connected to the server
+// FIXME: Use `std::sync::LazyLock` after it becomes stable in `1.80.0`
+static ONLINE_USERS: Lazy = Lazy::new(OnlineUsers::default);
+
+/// A user connected to the server
+pub struct SocketUserData {
+ /// Sender to send messages to the user
+ pub sender: mpsc::UnboundedSender>,
+ /// User public key
+ pub public_key: PublicKey,
+ /// Time that the user pinged at
+ pub pinged_at: chrono::DateTime,
+ /// Time that the user ponged at
+ pub ponged_at: chrono::DateTime,
+ /// User shared secret
+ pub shared_secret: [u8; 32],
+}
+
+impl SocketUserData {
+ /// Creates new [`SocketUserData`]
+ pub fn new(
+ public_key: PublicKey,
+ shared_secret: [u8; 32],
+ sender: mpsc::UnboundedSender>,
+ ) -> Self {
+ let now = Utc::now();
+ Self {
+ sender,
+ public_key,
+ shared_secret,
+ pinged_at: now,
+ ponged_at: now,
+ }
+ }
+}
+
+/// WebSocket handler, that handles the user connection
+#[handler]
+pub async fn user_connected(
+ req: &mut Request,
+ res: &mut Response,
+ depot: &Depot,
+) -> Result<(), StatusError> {
+ let nonce_cache = depot.nonce_cache();
+ let nonce_limit = *depot.nonce_cache_size();
+ let public_key =
+ utils::extract_public_key(req).expect("The public key was checked in the middleware");
+ // FIXME: The config should hold `K256Secret` not `PrivateKey`
+ let shared_secret =
+ K256Secret::from_privkey(&depot.config().server.private_key).shared_secret(&public_key);
+
+ WebSocketUpgrade::new()
+ .upgrade(req, res, move |ws| {
+ handle_socket(ws, nonce_cache, nonce_limit, public_key, shared_secret)
+ })
+ .await
+}
+
+/// Handle the websocket connection
+async fn handle_socket(
+ ws: WebSocket,
+ nonce_cache: Arc,
+ nonce_limit: usize,
+ user_public_key: PublicKey,
+ user_shared_secret: [u8; 32],
+) {
+ let (user_ws_sender, mut user_ws_receiver) = ws.split();
+
+ let (sender, receiver) = mpsc::unbounded();
+ let receiver = receiver.into_stream();
+ let fut = receiver.forward(user_ws_sender).map(|result| {
+ if let Err(err) = result {
+ log::error!("websocket send error: {err}");
+ }
+ });
+ tokio_spawn(fut);
+ let conn_id = Uuid::new_v4();
+ let user = SocketUserData::new(user_public_key, user_shared_secret, sender.clone());
+ ONLINE_USERS.add_user(&conn_id, user).await;
+ log::info!("New user connected: ConnId(={conn_id}) PublicKey(={user_public_key})");
+
+ let fut = async move {
+ while let Some(Ok(msg)) = user_ws_receiver.next().await {
+ match handle_ws_msg(msg, &nonce_cache, &nonce_limit, &user_shared_secret) {
+ Ok(event) => {
+ if let Some(server_event) = handle_events(event, &conn_id).await {
+ if let Err(err) = sender.unbounded_send(Ok(Message::from(
+ &server_event.sign(&user_shared_secret),
+ ))) {
+ log::error!("Websocket Error: {err}");
+ break;
+ }
+ };
+ }
+ Err(err) => {
+ if let Err(err) = sender.unbounded_send(Ok(Message::from(
+ &ServerEvent::from(err).sign(&user_shared_secret),
+ ))) {
+ log::error!("Websocket Error: {err}");
+ break;
+ };
+ }
+ };
+ }
+ user_disconnected(&conn_id, &user_public_key).await;
+ };
+ tokio_spawn(fut);
+}
+
+/// Handle websocket msg
+fn handle_ws_msg(
+ msg: Message,
+ nonce_cache: &NonceCache,
+ nonce_limit: &usize,
+ shared_secret: &[u8; 32],
+) -> WsResult {
+ let Ok(text) = msg.to_str() else {
+ return Err(WsError::NotTextMessage);
+ };
+ let event = serde_json::from_str::(text).map_err(|err| {
+ if err.is_data() {
+ WsError::UnknownClientEvent
+ } else {
+ WsError::InvalidJsonData
+ }
+ })?;
+ if !event.verify_signature(shared_secret, nonce_cache, nonce_limit) {
+ return Err(WsError::InvalidSignature);
+ }
+ Ok(event)
+}
+
+/// Handle user events, and return the server event if needed
+async fn handle_events(event: ClientEvent, conn_id: &Uuid) -> Option> {
+ match &event.event {
+ ClientEventType::Ping { .. } => Some(ServerEvent::pong()),
+ ClientEventType::Pong { .. } => {
+ ONLINE_USERS.update_pong(conn_id).await;
+ None
+ }
+ }
+}
+
+/// Handle user disconnected
+async fn user_disconnected(conn_id: &Uuid, public_key: &PublicKey) {
+ ONLINE_USERS.remove_user(conn_id).await;
+ log::debug!("User disconnect: ConnId(={conn_id}) PublicKey(={public_key})");
+}
+
+pub fn route() -> Router {
+ let users_pinger = async {
+ /// Seconds to wait for pongs, before disconnecting the user
+ const WAIT_FOR_PONGS_SECS: u32 = 10;
+ /// Seconds to sleep between pings (10 minutes)
+ const SLEEP_SECS: u32 = 60 * 10;
+ loop {
+ log::debug!("Start pinging online users");
+ ONLINE_USERS.ping_all().await;
+ tokio_sleep(Duration::from_secs(u64::from(WAIT_FOR_PONGS_SECS))).await;
+ ONLINE_USERS.disconnect_inactive_users().await;
+ log::debug!("Done pinging online users and disconnected inactive ones");
+ tokio_sleep(Duration::from_secs(u64::from(SLEEP_SECS))).await;
+ }
+ };
+
+ tokio_spawn(users_pinger);
+
+ Router::new()
+ .push(Router::with_path("chat").get(user_connected))
+ .hoop(middlewares::signature_check)
+ .hoop(middlewares::public_key_check)
+}
--
2.45.2
From d442e73ed7253f2fa8c381029058d51627de25b3 Mon Sep 17 00:00:00 2001
From: Awiteb
Date: Sat, 6 Jul 2024 01:25:06 +0300
Subject: [PATCH 4/4] change: Remove the nonce cache limit
Don't pass the nonce cache limit everywhere, we allocate the hashmap capacity
with it, so we can use the hashmap capacity directly.
Also refactor the `NonceCache` type, make it better without extension trait.
Suggested-by: Amjad Alsharafi
Reviewed-by: Amjad Alsharafi
Signed-off-by: Awiteb
---
crates/oxidetalis/src/extensions.rs | 44 +----------
crates/oxidetalis/src/main.rs | 7 +-
.../oxidetalis/src/middlewares/signature.rs | 2 +-
crates/oxidetalis/src/nonce.rs | 73 +++++++++++++++++++
crates/oxidetalis/src/routes/mod.rs | 26 ++-----
crates/oxidetalis/src/utils.rs | 10 +--
.../oxidetalis/src/websocket/events/client.rs | 7 +-
crates/oxidetalis/src/websocket/mod.rs | 13 ++--
8 files changed, 96 insertions(+), 86 deletions(-)
create mode 100644 crates/oxidetalis/src/nonce.rs
diff --git a/crates/oxidetalis/src/extensions.rs b/crates/oxidetalis/src/extensions.rs
index bc06067..3f4415e 100644
--- a/crates/oxidetalis/src/extensions.rs
+++ b/crates/oxidetalis/src/extensions.rs
@@ -24,9 +24,8 @@ use sea_orm::DatabaseConnection;
use uuid::Uuid;
use crate::{
- routes::DEPOT_NONCE_CACHE_SIZE,
+ nonce::NonceCache,
websocket::{OnlineUsers, ServerEvent, SocketUserData},
- NonceCache,
};
/// Extension trait for the Depot.
@@ -37,15 +36,6 @@ pub trait DepotExt {
fn config(&self) -> &Config;
/// Retutns the nonce cache
fn nonce_cache(&self) -> Arc;
- /// Returns the size of the nonce cache
- fn nonce_cache_size(&self) -> &usize;
-}
-
-/// Extension trait for the nonce cache.
-pub trait NonceCacheExt {
- /// Add a nonce to the cache, returns `true` if the nonce is added, `false`
- /// if the nonce is already exist in the cache.
- fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool;
}
/// Extension trait for online websocket users
@@ -82,38 +72,6 @@ impl DepotExt for Depot {
.expect("Nonce cache not found"),
)
}
-
- fn nonce_cache_size(&self) -> &usize {
- let s: &Arc = self
- .get(DEPOT_NONCE_CACHE_SIZE)
- .expect("Nonce cache size not found");
- s.as_ref()
- }
-}
-
-impl NonceCacheExt for &NonceCache {
- fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool {
- let mut cache = self.lock().expect("Nonce cache lock poisoned, aborting...");
- let now = Utc::now().timestamp();
- cache.retain(|_, time| (now - *time) < 30);
-
- if &cache.len() >= limit {
- log::warn!("Nonce cache limit reached, clearing 10% of the cache");
- let num_to_remove = limit / 10;
- let keys: Vec<[u8; 16]> = cache.keys().copied().collect();
- for key in keys.iter().take(num_to_remove) {
- cache.remove(key);
- }
- }
-
- // We can use insert directly, but it's will update the value if the key is
- // already exist so we need to check if the key is already exist or not
- if cache.contains_key(nonce) {
- return false;
- }
- cache.insert(*nonce, now);
- true
- }
}
impl OnlineUsersExt for OnlineUsers {
diff --git a/crates/oxidetalis/src/main.rs b/crates/oxidetalis/src/main.rs
index a440c74..53ad66d 100644
--- a/crates/oxidetalis/src/main.rs
+++ b/crates/oxidetalis/src/main.rs
@@ -17,7 +17,7 @@
#![doc = include_str!("../../../README.md")]
#![warn(missing_docs, unsafe_code)]
-use std::{collections::HashMap, process::ExitCode, sync::Mutex};
+use std::process::ExitCode;
use oxidetalis_config::{CliArgs, Parser};
use oxidetalis_migrations::MigratorTrait;
@@ -27,14 +27,12 @@ mod database;
mod errors;
mod extensions;
mod middlewares;
+mod nonce;
mod routes;
mod schemas;
mod utils;
mod websocket;
-/// Nonce cache type, used to store nonces for a certain amount of time
-pub type NonceCache = Mutex>;
-
async fn try_main() -> errors::Result<()> {
pretty_env_logger::init_timed();
@@ -50,6 +48,7 @@ async fn try_main() -> errors::Result<()> {
let local_addr = format!("{}:{}", config.server.host, config.server.port);
let acceptor = TcpListener::new(&local_addr).bind().await;
log::info!("Server listening on http://{local_addr}");
+ log::info!("Chat websocket on ws://{local_addr}/ws/chat");
if config.openapi.enable {
log::info!(
"The openapi schema is available at http://{local_addr}{}",
diff --git a/crates/oxidetalis/src/middlewares/signature.rs b/crates/oxidetalis/src/middlewares/signature.rs
index 463a1fc..c94d0de 100644
--- a/crates/oxidetalis/src/middlewares/signature.rs
+++ b/crates/oxidetalis/src/middlewares/signature.rs
@@ -70,7 +70,7 @@ pub async fn signature_check(
}
};
- if !utils::is_valid_nonce(&signature, &depot.nonce_cache(), depot.nonce_cache_size())
+ if !utils::is_valid_nonce(&signature, &depot.nonce_cache()).await
|| !utils::is_valid_signature(
&sender_public_key,
&depot.config().server.private_key,
diff --git a/crates/oxidetalis/src/nonce.rs b/crates/oxidetalis/src/nonce.rs
new file mode 100644
index 0000000..75b247d
--- /dev/null
+++ b/crates/oxidetalis/src/nonce.rs
@@ -0,0 +1,73 @@
+// OxideTalis Messaging Protocol homeserver implementation
+// Copyright (C) 2024 OxideTalis Developers
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+//! Nonce cache implementation
+
+use std::{collections::HashMap, mem};
+
+use chrono::Utc;
+use oxidetalis_core::types::Size;
+use tokio::sync::Mutex as TokioMutex;
+
+/// Size of each entry in the nonce cache
+pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::();
+/// Size of the hashmap itself without the entrys (48 bytes)
+pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::>();
+
+/// Nonce cache struct, used to store nonces for a short period of time
+/// to prevent replay attacks, each nonce has a 30 seconds lifetime.
+///
+/// The cache will remove first 10% nonces if the cache limit is reached.
+pub struct NonceCache {
+ /// The nonce cache hashmap, the key is the nonce and the value is the time
+ cache: TokioMutex>,
+}
+
+impl NonceCache {
+ /// Creates new [`NonceCache`] instance, with the given cache limit
+ pub fn new(cache_limit: &Size) -> Self {
+ Self {
+ cache: TokioMutex::new(HashMap::with_capacity(
+ (cache_limit.as_bytes() - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
+ )),
+ }
+ }
+
+ /// Add a nonce to the cache, returns `true` if the nonce is added, `false`
+ /// if the nonce is already exist in the cache.
+ pub async fn add_nonce(&self, nonce: &[u8; 16]) -> bool {
+ let mut cache = self.cache.lock().await;
+ let now = Utc::now().timestamp();
+ cache.retain(|_, time| (now - *time) < 30);
+
+ if cache.len() == cache.capacity() {
+ log::warn!("Nonce cache limit reached, clearing 10% of the cache");
+ let num_to_remove = cache.capacity() / 10;
+ let keys: Vec<[u8; 16]> = cache.keys().copied().collect();
+ for key in keys.iter().take(num_to_remove) {
+ cache.remove(key);
+ }
+ }
+
+ // We can use insert directly, but it's will update the value if the key is
+ // already exist so we need to check if the key is already exist or not
+ if cache.contains_key(nonce) {
+ return false;
+ }
+ cache.insert(*nonce, now);
+ true
+ }
+}
diff --git a/crates/oxidetalis/src/routes/mod.rs b/crates/oxidetalis/src/routes/mod.rs
index 06b9a63..45bdbd5 100644
--- a/crates/oxidetalis/src/routes/mod.rs
+++ b/crates/oxidetalis/src/routes/mod.rs
@@ -14,9 +14,8 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
-use std::collections::HashMap;
-use std::sync::{Arc, Mutex};
-use std::{env, mem};
+use std::env;
+use std::sync::Arc;
use oxidetalis_config::Config;
use salvo::http::ResBody;
@@ -24,18 +23,12 @@ use salvo::oapi::{Info, License};
use salvo::rate_limiter::{BasicQuota, FixedGuard, MokaStore, RateLimiter, RemoteIpIssuer};
use salvo::{catcher::Catcher, logging::Logger, prelude::*};
+use crate::nonce::NonceCache;
use crate::schemas::MessageSchema;
-use crate::{middlewares, websocket, NonceCache};
+use crate::{middlewares, websocket};
mod user;
-/// Size of each entry in the nonce cache
-pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::();
-/// Size of the hashmap itself without the entrys (48 bytes)
-pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::>();
-/// Name of the nonce cache size in the depot
-pub(crate) const DEPOT_NONCE_CACHE_SIZE: &str = "NONCE_CACHE_SIZE";
-
pub fn write_json_body(res: &mut Response, json_body: impl serde::Serialize) {
res.write_body(serde_json::to_string(&json_body).expect("Json serialization can't be fail"))
.ok();
@@ -69,7 +62,7 @@ async fn handle_server_errors(res: &mut Response, ctrl: &mut FlowCtrl) {
err.brief.trim_end_matches('.'),
err.cause
.as_deref()
- .map_or_else(|| "".to_owned(), ToString::to_string)
+ .map_or_else(String::new, ToString::to_string)
.trim_end_matches('.')
.split(':')
.last()
@@ -129,13 +122,9 @@ fn route_openapi(config: &Config, router: Router) -> Router {
}
pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service {
- let nonce_cache_size = config.server.nonce_cache_size.as_bytes();
- let nonce_cache: NonceCache = Mutex::new(HashMap::with_capacity(
- (nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
- ));
+ let nonce_cache: NonceCache = NonceCache::new(&config.server.nonce_cache_size);
log::info!(
- "Nonce cache created with a capacity of {} ({})",
- (nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
+ "Nonce cache created with a capacity of {}",
config.server.nonce_cache_size
);
@@ -146,7 +135,6 @@ pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service {
.hoop(Logger::new())
.hoop(
affix::inject(Arc::new(conn))
- .insert(DEPOT_NONCE_CACHE_SIZE, Arc::new(nonce_cache_size))
.inject(Arc::new(config.clone()))
.inject(Arc::new(nonce_cache)),
);
diff --git a/crates/oxidetalis/src/utils.rs b/crates/oxidetalis/src/utils.rs
index 76f10c1..861926b 100644
--- a/crates/oxidetalis/src/utils.rs
+++ b/crates/oxidetalis/src/utils.rs
@@ -27,7 +27,7 @@ use oxidetalis_core::{
};
use salvo::Request;
-use crate::{extensions::NonceCacheExt, NonceCache};
+use crate::nonce::NonceCache;
/// Returns the postgres database url
#[logcall]
@@ -43,16 +43,12 @@ pub(crate) fn postgres_url(db_config: &Postgres) -> String {
}
/// Returns true if the given nonce a valid nonce.
-pub(crate) fn is_valid_nonce(
- signature: &Signature,
- nonce_cache: &NonceCache,
- nonce_cache_limit: &usize,
-) -> bool {
+pub(crate) async fn is_valid_nonce(signature: &Signature, nonce_cache: &NonceCache) -> bool {
let new_timestamp = Utc::now()
.timestamp()
.checked_sub(u64::from_be_bytes(*signature.timestamp()) as i64)
.is_some_and(|n| n <= 20);
- let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce(), nonce_cache_limit);
+ let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce()).await;
new_timestamp && unused_nonce
}
diff --git a/crates/oxidetalis/src/websocket/events/client.rs b/crates/oxidetalis/src/websocket/events/client.rs
index 25e688b..a4dda31 100644
--- a/crates/oxidetalis/src/websocket/events/client.rs
+++ b/crates/oxidetalis/src/websocket/events/client.rs
@@ -19,7 +19,7 @@
use oxidetalis_core::types::Signature;
use serde::{Deserialize, Serialize};
-use crate::{utils, NonceCache};
+use crate::{nonce::NonceCache, utils};
/// Client websocket event
#[derive(Deserialize, Clone, Debug)]
@@ -49,13 +49,12 @@ impl ClientEventType {
impl ClientEvent {
/// Verify the signature of the event
- pub fn verify_signature(
+ pub async fn verify_signature(
&self,
shared_secret: &[u8; 32],
nonce_cache: &NonceCache,
- nonce_limit: &usize,
) -> bool {
- utils::is_valid_nonce(&self.signature, nonce_cache, nonce_limit)
+ utils::is_valid_nonce(&self.signature, nonce_cache).await
&& self.signature.verify(&self.event.data(), shared_secret)
}
}
diff --git a/crates/oxidetalis/src/websocket/mod.rs b/crates/oxidetalis/src/websocket/mod.rs
index 4f8aa6a..c91479e 100644
--- a/crates/oxidetalis/src/websocket/mod.rs
+++ b/crates/oxidetalis/src/websocket/mod.rs
@@ -41,8 +41,8 @@ use uuid::Uuid;
use crate::{
extensions::{DepotExt, OnlineUsersExt},
middlewares,
+ nonce::NonceCache,
utils,
- NonceCache,
};
/// Online users type
@@ -92,7 +92,6 @@ pub async fn user_connected(
depot: &Depot,
) -> Result<(), StatusError> {
let nonce_cache = depot.nonce_cache();
- let nonce_limit = *depot.nonce_cache_size();
let public_key =
utils::extract_public_key(req).expect("The public key was checked in the middleware");
// FIXME: The config should hold `K256Secret` not `PrivateKey`
@@ -101,7 +100,7 @@ pub async fn user_connected(
WebSocketUpgrade::new()
.upgrade(req, res, move |ws| {
- handle_socket(ws, nonce_cache, nonce_limit, public_key, shared_secret)
+ handle_socket(ws, nonce_cache, public_key, shared_secret)
})
.await
}
@@ -110,7 +109,6 @@ pub async fn user_connected(
async fn handle_socket(
ws: WebSocket,
nonce_cache: Arc,
- nonce_limit: usize,
user_public_key: PublicKey,
user_shared_secret: [u8; 32],
) {
@@ -131,7 +129,7 @@ async fn handle_socket(
let fut = async move {
while let Some(Ok(msg)) = user_ws_receiver.next().await {
- match handle_ws_msg(msg, &nonce_cache, &nonce_limit, &user_shared_secret) {
+ match handle_ws_msg(msg, &nonce_cache, &user_shared_secret).await {
Ok(event) => {
if let Some(server_event) = handle_events(event, &conn_id).await {
if let Err(err) = sender.unbounded_send(Ok(Message::from(
@@ -158,10 +156,9 @@ async fn handle_socket(
}
/// Handle websocket msg
-fn handle_ws_msg(
+async fn handle_ws_msg(
msg: Message,
nonce_cache: &NonceCache,
- nonce_limit: &usize,
shared_secret: &[u8; 32],
) -> WsResult {
let Ok(text) = msg.to_str() else {
@@ -174,7 +171,7 @@ fn handle_ws_msg(
WsError::InvalidJsonData
}
})?;
- if !event.verify_signature(shared_secret, nonce_cache, nonce_limit) {
+ if !event.verify_signature(shared_secret, nonce_cache).await {
return Err(WsError::InvalidSignature);
}
Ok(event)
--
2.45.2