feat: Initialize server websocket #8

Merged
awiteb merged 4 commits from awiteb/init-websocket into master 2024-07-06 13:35:10 +02:00 AGit
13 changed files with 660 additions and 67 deletions

Cargo.lock (generated)
View file

@@ -626,6 +626,16 @@ dependencies = [
  "crossbeam-utils",
 ]
 
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-epoch"
 version = "0.9.18"
@@ -958,6 +968,7 @@ checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
 dependencies = [
  "futures-channel",
  "futures-core",
+ "futures-executor",
  "futures-io",
  "futures-sink",
  "futures-task",
@@ -1895,19 +1906,23 @@ version = "0.1.0"
 dependencies = [
  "chrono",
  "derive-new",
+ "futures",
  "log",
  "logcall",
+ "once_cell",
  "oxidetalis_config",
  "oxidetalis_core",
  "oxidetalis_entities",
  "oxidetalis_migrations",
  "pretty_env_logger",
+ "rayon",
  "salvo",
  "sea-orm",
  "serde",
  "serde_json",
  "thiserror",
  "tokio",
+ "uuid",
 ]
 
 [[package]]
@@ -2258,6 +2273,26 @@ dependencies = [
  "bitflags 2.6.0",
 ]
 
+[[package]]
+name = "rayon"
+version = "1.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "redox_syscall"
 version = "0.4.1"

View file

@@ -23,9 +23,13 @@ thiserror = { workspace = true }
 chrono = { workspace = true }
 salvo = { version = "0.68.2", features = ["rustls", "affix", "logging", "oapi", "rate-limiter", "websocket"] }
 tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] }
+uuid = { version = "1.9.1", default-features = false, features = ["v4"] }
 derive-new = "0.6.0"
 pretty_env_logger = "0.5.0"
 serde_json = "1.0.117"
+once_cell = "1.19.0"
+futures = "0.3.30"
+rayon = "1.10.0"
 
 [lints.rust]
 unsafe_code = "deny"

View file

@@ -18,10 +18,15 @@ use std::sync::Arc;
 
 use chrono::Utc;
 use oxidetalis_config::Config;
-use salvo::Depot;
+use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
+use salvo::{websocket::Message, Depot};
 use sea_orm::DatabaseConnection;
+use uuid::Uuid;
 
-use crate::{routes::DEPOT_NONCE_CACHE_SIZE, NonceCache};
+use crate::{
+    nonce::NonceCache,
+    websocket::{OnlineUsers, ServerEvent, SocketUserData},
+};
 
 /// Extension trait for the Depot.
 pub trait DepotExt {
@@ -31,15 +36,24 @@ pub trait DepotExt {
     fn config(&self) -> &Config;
     /// Retutns the nonce cache
     fn nonce_cache(&self) -> Arc<NonceCache>;
-    /// Returns the size of the nonce cache
-    fn nonce_cache_size(&self) -> &usize;
 }
 
-/// Extension trait for the nonce cache.
-pub trait NonceCacheExt {
-    /// Add a nonce to the cache, returns `true` if the nonce is added, `false`
-    /// if the nonce is already exist in the cache.
-    fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool;
+/// Extension trait for online websocket users
+pub trait OnlineUsersExt {
+    /// Add new user to the online users
+    async fn add_user(&self, conn_id: &Uuid, data: SocketUserData);
+
+    /// Remove user from online users
+    async fn remove_user(&self, conn_id: &Uuid);
+
+    /// Ping all online users
+    async fn ping_all(&self);
+
+    /// Update user pong at time
+    async fn update_pong(&self, conn_id: &Uuid);
+
+    /// Disconnect inactive users (who not respond for the ping event)
+    async fn disconnect_inactive_users(&self);
 }
 
 impl DepotExt for Depot {
@@ -58,36 +72,42 @@ impl DepotExt for Depot {
                 .expect("Nonce cache not found"),
         )
     }
-
-    fn nonce_cache_size(&self) -> &usize {
-        let s: &Arc<usize> = self
-            .get(DEPOT_NONCE_CACHE_SIZE)
-            .expect("Nonce cache size not found");
-        s.as_ref()
-    }
 }
 
-impl NonceCacheExt for &NonceCache {
-    fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool {
-        let mut cache = self.lock().expect("Nonce cache lock poisoned, aborting...");
-        let now = Utc::now().timestamp();
-        cache.retain(|_, time| (now - *time) < 30);
-
-        if &cache.len() >= limit {
-            log::warn!("Nonce cache limit reached, clearing 10% of the cache");
-            let num_to_remove = limit / 10;
-            let keys: Vec<[u8; 16]> = cache.keys().copied().collect();
-            for key in keys.iter().take(num_to_remove) {
-                cache.remove(key);
-            }
-        }
-
-        // We can use insert directly, but it's will update the value if the key is
-        // already exist so we need to check if the key is already exist or not
-        if cache.contains_key(nonce) {
-            return false;
-        }
-        cache.insert(*nonce, now);
-        true
+impl OnlineUsersExt for OnlineUsers {
+    async fn add_user(&self, conn_id: &Uuid, data: SocketUserData) {
+        self.write().await.insert(*conn_id, data);
+    }
+
+    async fn remove_user(&self, conn_id: &Uuid) {
+        self.write().await.remove(conn_id);
+    }
+
+    async fn ping_all(&self) {
+        let now = Utc::now();
+        self.write().await.par_iter_mut().for_each(|(_, u)| {
+            u.pinged_at = now;
+            let _ = u.sender.unbounded_send(Ok(Message::from(
+                &ServerEvent::ping().sign(&u.shared_secret),
+            )));
+        });
+    }
+
+    async fn update_pong(&self, conn_id: &Uuid) {
+        if let Some(user) = self.write().await.get_mut(conn_id) {
+            user.ponged_at = Utc::now()
+        }
+    }
+
+    async fn disconnect_inactive_users(&self) {
+        self.write().await.retain(|_, u| {
+            // if we send ping and the client doesn't send pong
+            if u.pinged_at > u.ponged_at {
+                log::info!("Disconnected from {}, inactive", u.public_key);
+                u.sender.close_channel();
+                return false;
+            }
+            true
+        });
     }
 }

View file

@@ -17,7 +17,7 @@
 #![doc = include_str!("../../../README.md")]
 #![warn(missing_docs, unsafe_code)]
 
-use std::{collections::HashMap, process::ExitCode, sync::Mutex};
+use std::process::ExitCode;
 
 use oxidetalis_config::{CliArgs, Parser};
 use oxidetalis_migrations::MigratorTrait;
@@ -27,12 +27,11 @@ mod database;
 mod errors;
 mod extensions;
 mod middlewares;
+mod nonce;
 mod routes;
 mod schemas;
 mod utils;
+mod websocket;
 
-/// Nonce cache type, used to store nonces for a certain amount of time
-pub type NonceCache = Mutex<HashMap<[u8; 16], i64>>;
-
 async fn try_main() -> errors::Result<()> {
     pretty_env_logger::init_timed();
@@ -49,6 +48,7 @@ async fn try_main() -> errors::Result<()> {
     let local_addr = format!("{}:{}", config.server.host, config.server.port);
     let acceptor = TcpListener::new(&local_addr).bind().await;
     log::info!("Server listening on http://{local_addr}");
+    log::info!("Chat websocket on ws://{local_addr}/ws/chat");
     if config.openapi.enable {
         log::info!(
             "The openapi schema is available at http://{local_addr}{}",

View file

@@ -70,7 +70,7 @@ pub async fn signature_check(
         }
     };
 
-    if !utils::is_valid_nonce(&signature, &depot.nonce_cache(), depot.nonce_cache_size())
+    if !utils::is_valid_nonce(&signature, &depot.nonce_cache()).await
         || !utils::is_valid_signature(
             &sender_public_key,
             &depot.config().server.private_key,

View file

@@ -0,0 +1,73 @@
// OxideTalis Messaging Protocol homeserver implementation
// Copyright (C) 2024 OxideTalis Developers <otmp@4rs.nl>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
//! Nonce cache implementation

use std::{collections::HashMap, mem};

use chrono::Utc;
use oxidetalis_core::types::Size;
use tokio::sync::Mutex as TokioMutex;

/// Size of each entry in the nonce cache
pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::<i16>();
/// Size of the hashmap itself without the entrys (48 bytes)
pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::<HashMap<u8, u8>>();

/// Nonce cache struct, used to store nonces for a short period of time
/// to prevent replay attacks, each nonce has a 30 seconds lifetime.
///
/// The cache will remove first 10% nonces if the cache limit is reached.
pub struct NonceCache {
    /// The nonce cache hashmap, the key is the nonce and the value is the time
awiteb marked this conversation as resolved
Review

This is not an issue, but I'm curious:

  • why not put the mutex here inside, instead of passing Mutex<NonceCache> around, we can use NonceCache and handle locking inside.
  • why not use tokio mutex in this, we would need to change many functions to be async of course, not sure which is better.
Review

why not put the mutex here inside

Actually a good idea, I'll.

not sure which is better.

I'll see

    cache: TokioMutex<HashMap<[u8; 16], i64>>,
}

impl NonceCache {
    /// Creates new [`NonceCache`] instance, with the given cache limit
    pub fn new(cache_limit: &Size) -> Self {
        Self {
            cache: TokioMutex::new(HashMap::with_capacity(
                (cache_limit.as_bytes() - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
            )),
        }
    }

    /// Add a nonce to the cache, returns `true` if the nonce is added, `false`
    /// if the nonce is already exist in the cache.
    pub async fn add_nonce(&self, nonce: &[u8; 16]) -> bool {
        let mut cache = self.cache.lock().await;
        let now = Utc::now().timestamp();
        cache.retain(|_, time| (now - *time) < 30);

        if cache.len() == cache.capacity() {
            log::warn!("Nonce cache limit reached, clearing 10% of the cache");
            let num_to_remove = cache.capacity() / 10;
            let keys: Vec<[u8; 16]> = cache.keys().copied().collect();
            for key in keys.iter().take(num_to_remove) {
                cache.remove(key);
            }
        }

        // We can use insert directly, but it's will update the value if the key is
        // already exist so we need to check if the key is already exist or not
        if cache.contains_key(nonce) {
            return false;
        }
        cache.insert(*nonce, now);
        true
    }
}
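For illustration, a minimal usage sketch of the replay check this cache provides (written against the `NonceCache` added above; not part of this PR's diff, and `cache_limit` stands in for the configured `nonce_cache_size`):

// Illustrative sketch only: the same nonce is accepted once and rejected if it
// is seen again within its 30-second lifetime, which is what is_valid_nonce
// relies on.
async fn nonce_replay_demo(cache_limit: &oxidetalis_core::types::Size) {
    let cache = NonceCache::new(cache_limit);
    let nonce = [1u8; 16];

    assert!(cache.add_nonce(&nonce).await); // first use: stored, returns true
    assert!(!cache.add_nonce(&nonce).await); // replay: already cached, returns false
}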

View file

@@ -14,9 +14,8 @@
 // You should have received a copy of the GNU Affero General Public License
 // along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
 
-use std::collections::HashMap;
-use std::sync::{Arc, Mutex};
-use std::{env, mem};
+use std::env;
+use std::sync::Arc;
 
 use oxidetalis_config::Config;
 use salvo::http::ResBody;
@@ -24,18 +23,12 @@ use salvo::oapi::{Info, License};
 use salvo::rate_limiter::{BasicQuota, FixedGuard, MokaStore, RateLimiter, RemoteIpIssuer};
 use salvo::{catcher::Catcher, logging::Logger, prelude::*};
 
+use crate::nonce::NonceCache;
 use crate::schemas::MessageSchema;
-use crate::{middlewares, NonceCache};
+use crate::{middlewares, websocket};
 
 mod user;
 
-/// Size of each entry in the nonce cache
-pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::<i16>();
-/// Size of the hashmap itself without the entrys (48 bytes)
-pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::<HashMap<u8, u8>>();
-/// Name of the nonce cache size in the depot
-pub(crate) const DEPOT_NONCE_CACHE_SIZE: &str = "NONCE_CACHE_SIZE";
-
 pub fn write_json_body(res: &mut Response, json_body: impl serde::Serialize) {
     res.write_body(serde_json::to_string(&json_body).expect("Json serialization can't be fail"))
         .ok();
@@ -69,7 +62,7 @@ async fn handle_server_errors(res: &mut Response, ctrl: &mut FlowCtrl) {
                     err.brief.trim_end_matches('.'),
                     err.cause
                         .as_deref()
-                        .map_or_else(|| "".to_owned(), ToString::to_string)
+                        .map_or_else(String::new, ToString::to_string)
                         .trim_end_matches('.')
                         .split(':')
                         .last()
@@ -129,23 +122,19 @@ fn route_openapi(config: &Config, router: Router) -> Router {
 }
 
 pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service {
-    let nonce_cache_size = config.server.nonce_cache_size.as_bytes();
-    let nonce_cache: NonceCache = Mutex::new(HashMap::with_capacity(
-        (nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
-    ));
+    let nonce_cache: NonceCache = NonceCache::new(&config.server.nonce_cache_size);
     log::info!(
-        "Nonce cache created with a capacity of {} ({})",
-        (nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
+        "Nonce cache created with a capacity of {}",
         config.server.nonce_cache_size
     );
 
     let router = Router::new()
         .push(Router::with_path("user").push(user::route()))
+        .push(Router::with_path("ws").push(websocket::route()))
         .hoop(middlewares::add_server_headers)
         .hoop(Logger::new())
        .hoop(
             affix::inject(Arc::new(conn))
-                .insert(DEPOT_NONCE_CACHE_SIZE, Arc::new(nonce_cache_size))
                 .inject(Arc::new(config.clone()))
                 .inject(Arc::new(nonce_cache)),
         );

View file

@@ -27,7 +27,7 @@ use oxidetalis_core::{
 };
 use salvo::Request;
 
-use crate::{extensions::NonceCacheExt, NonceCache};
+use crate::nonce::NonceCache;
 
 /// Returns the postgres database url
 #[logcall]
@@ -43,16 +43,12 @@ pub(crate) fn postgres_url(db_config: &Postgres) -> String {
 }
 
 /// Returns true if the given nonce a valid nonce.
-pub(crate) fn is_valid_nonce(
-    signature: &Signature,
-    nonce_cache: &NonceCache,
-    nonce_cache_limit: &usize,
-) -> bool {
+pub(crate) async fn is_valid_nonce(signature: &Signature, nonce_cache: &NonceCache) -> bool {
     let new_timestamp = Utc::now()
         .timestamp()
         .checked_sub(u64::from_be_bytes(*signature.timestamp()) as i64)
         .is_some_and(|n| n <= 20);
 
-    let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce(), nonce_cache_limit);
+    let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce()).await;
     new_timestamp && unused_nonce
 }

View file

@@ -0,0 +1,57 @@
// OxideTalis Messaging Protocol homeserver implementation
// Copyright (C) 2024 OxideTalis Developers <otmp@4rs.nl>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
//! Websocket errors

/// Result type of websocket
pub type WsResult<T> = Result<T, WsError>;

/// Websocket errors, returned in the websocket communication
#[derive(Debug)]
pub enum WsError {
    /// The signature is invalid
    InvalidSignature,
    /// Message type must be text
    NotTextMessage,
    /// Invalid json data
    InvalidJsonData,
    /// Unknown client event
    UnknownClientEvent,
}

impl WsError {
    /// Returns error name
    pub const fn name(&self) -> &'static str {
        match self {
            WsError::InvalidSignature => "InvalidSignature",
            WsError::NotTextMessage => "NotTextMessage",
            WsError::InvalidJsonData => "InvalidJsonData",
            WsError::UnknownClientEvent => "UnknownClientEvent",
        }
    }

    /// Returns the error reason
    pub const fn reason(&self) -> &'static str {
        match self {
            WsError::InvalidSignature => "Invalid event signature",
            WsError::NotTextMessage => "The websocket message must be text message",
            WsError::InvalidJsonData => "Received invalid json data, the text must be valid json",
            WsError::UnknownClientEvent => {
                "Unknown client event, the event is not recognized by the server"
            }
        }
    }
}

View file

@@ -0,0 +1,60 @@
// OxideTalis Messaging Protocol homeserver implementation
// Copyright (C) 2024 OxideTalis Developers <otmp@4rs.nl>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
//! Events that the client send it
use oxidetalis_core::types::Signature;
use serde::{Deserialize, Serialize};
use crate::{nonce::NonceCache, utils};
/// Client websocket event
#[derive(Deserialize, Clone, Debug)]
pub struct ClientEvent {
awiteb marked this conversation as resolved
Review

The separation between the type and event data and then checking manually with is_of_type looks wrong.

Why not do it in a single enum to host all events and data

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "event", content = "data")]
pub enum ClientEventType {
    Ping { timestamp: u64 },
    Pong { timestamp: u64 },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ClientEvent {
    #[serde(flatten)]
    pub event_type: ClientEventType,
    pub signature: String,
}

There is some duplication now in the ping/pong timestamp, but it would be better in the long run with more events and data.

Review

Bro!!! I didn't know about tag-content (https://serde.rs/container-attrs.html#tag--content), is amazing
Review

It's good, but we have a problem. We need the event data for the signature, but we can't extract it directly in this enum.

I wrote a function that extracts the data by serializing the enum variant to serde_json::Value and then extracting the data from it, what do you think?

impl ClientEventType {
    /// Returns event data as json bytes
    pub fn data(&self) -> Vec<u8> {
        serde_json::to_vec(&serde_json::to_value(self).expect("can't fail")["data"])
            .expect("can't fail")
    }
}

This will output {"timestamp":4398} as bytes, we will make the signature from it

Review

I would use this

serde_json::to_value(&self.event_type).unwrap()["data"].to_string().into_bytes())

first, we don't need to serialize the whole thing as it would include the signature and then we will discard everything except for data, so better to serialize only the smallest needed parts.
second, we can directly use to_string from Value, and take the bytes out of it, I think it's a bit better since to_vec would call the serializer for Value

Review

small thing, maybe better to rename event_type to event? shorter, and it's not just the type, as it contains the data as well

Review

first, we don't need to serialize the whole thing as it would include the signature and then we will discard everything except for data

Right, the function above is for ClientEventType not Event.

I'll force push here today.

Review

Ah yes, didn't notice

    pub event: ClientEventType,
    signature: Signature,
}

/// Client websocket event type
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
#[serde(rename_all = "PascalCase", tag = "event", content = "data")]
pub enum ClientEventType {
    /// Ping event
    Ping { timestamp: u64 },
    /// Pong event
    Pong { timestamp: u64 },
}

impl ClientEventType {
    /// Returns event data as json bytes
    pub fn data(&self) -> Vec<u8> {
        serde_json::to_value(self).expect("can't fail")["data"]
            .to_string()
            .into_bytes()
awiteb marked this conversation as resolved
Review

this is different from ServerEventType; there, it's using .to_string().into_bytes()

Review

Forget it 😬

    }
}
impl ClientEvent {
    /// Verify the signature of the event
    pub async fn verify_signature(
        &self,
        shared_secret: &[u8; 32],
        nonce_cache: &NonceCache,
    ) -> bool {
        utils::is_valid_nonce(&self.signature, nonce_cache).await
            && self.signature.verify(&self.event.data(), shared_secret)
    }
}
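To make the adjacently tagged representation discussed in the thread above concrete, here is a small self-contained sketch; the enum mirrors ClientEventType for illustration only and is not code from this PR (it assumes the serde crate with the derive feature, plus serde_json):

use serde::{Deserialize, Serialize};

// Mirror of ClientEventType, for illustration only.
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
#[serde(rename_all = "PascalCase", tag = "event", content = "data")]
enum EventType {
    Ping { timestamp: u64 },
    Pong { timestamp: u64 },
}

fn main() {
    let event = EventType::Ping { timestamp: 4398 };

    // Adjacent tagging puts the variant name under "event" and its fields under "data".
    assert_eq!(
        serde_json::to_string(&event).unwrap(),
        r#"{"event":"Ping","data":{"timestamp":4398}}"#
    );

    // The same extraction data() performs: take the "data" value and keep its JSON bytes.
    let data = serde_json::to_value(&event).unwrap()["data"]
        .to_string()
        .into_bytes();
    assert_eq!(data, br#"{"timestamp":4398}"#.to_vec());
}

This matches the {"timestamp":4398} bytes mentioned in the review, which is what the signature is computed over.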

View file

@@ -0,0 +1,23 @@
// OxideTalis Messaging Protocol homeserver implementation
// Copyright (C) 2024 OxideTalis Developers <otmp@4rs.nl>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
//! Server and client websocket events
mod client;
mod server;
pub use client::*;
pub use server::*;

View file

@@ -0,0 +1,117 @@
// OxideTalis Messaging Protocol homeserver implementation
// Copyright (C) 2024 OxideTalis Developers <otmp@4rs.nl>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
//! Events that the server send it

use std::marker::PhantomData;

use chrono::Utc;
use oxidetalis_core::{cipher::K256Secret, types::Signature};
use salvo::websocket::Message;
use serde::Serialize;

use crate::websocket::errors::WsError;

/// Signed marker, used to indicate that the event is signed
pub struct Signed;
/// Unsigned marker, used to indicate that the event is unsigned
pub struct Unsigned;

/// Server websocket event
#[derive(Serialize, Clone, Debug)]
pub struct ServerEvent<T> {
    #[serde(flatten)]
    event: ServerEventType,
    signature: Signature,
    #[serde(skip)]
    phantom: PhantomData<T>,
}

/// server websocket event type
#[derive(Serialize, Clone, Eq, PartialEq, Debug)]
#[serde(rename_all = "PascalCase")]
pub enum ServerEventType {
    /// Ping event
    Ping { timestamp: u64 },
    /// Pong event
    Pong { timestamp: u64 },
    /// Error event
    Error {
        name: &'static str,
        reason: &'static str,
    },
}

impl ServerEventType {
    /// Returns event data as json bytes
    pub fn data(&self) -> Vec<u8> {
        serde_json::to_value(self).expect("can't fail")["data"]
            .to_string()
            .into_bytes()
    }
}

impl ServerEvent<Unsigned> {
    /// Creates new [`ServerEvent`]
    pub fn new(event: ServerEventType) -> Self {
        Self {
            event,
            signature: Signature::from([0u8; 56]),
            phantom: PhantomData,
        }
    }

    /// Creates ping event
    pub fn ping() -> Self {
        Self::new(ServerEventType::Ping {
            timestamp: Utc::now().timestamp() as u64,
        })
    }

    /// Creates pong event
    pub fn pong() -> Self {
        Self::new(ServerEventType::Pong {
            timestamp: Utc::now().timestamp() as u64,
        })
    }

    /// Sign the event
    pub fn sign(self, shared_secret: &[u8; 32]) -> ServerEvent<Signed> {
        ServerEvent::<Signed> {
            signature: K256Secret::sign_with_shared_secret(
                &serde_json::to_vec(&self.event.data()).expect("Can't fail"),
                shared_secret,
            ),
            event: self.event,
            phantom: PhantomData,
        }
    }
}

impl From<&ServerEvent<Signed>> for Message {
    fn from(value: &ServerEvent<Signed>) -> Self {
        Message::text(serde_json::to_string(value).expect("This can't fail"))
    }
}

impl From<WsError> for ServerEvent<Unsigned> {
    fn from(err: WsError) -> Self {
        ServerEvent::new(ServerEventType::Error {
            name: err.name(),
            reason: err.reason(),
        })
    }
}
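A side note on the Signed/Unsigned markers above: they form a small typestate, so only a signed event can be converted into a websocket Message and an unsigned one can never reach the wire by mistake. A minimal standalone sketch of the same pattern (all names here are illustrative, not from this PR):

use std::marker::PhantomData;

struct Draft;
struct Sealed;

struct Envelope<State> {
    payload: String,
    _state: PhantomData<State>,
}

impl Envelope<Draft> {
    fn new(payload: impl Into<String>) -> Self {
        Envelope { payload: payload.into(), _state: PhantomData }
    }

    // Consuming the draft is the only way to obtain a sealed envelope.
    fn seal(self) -> Envelope<Sealed> {
        Envelope { payload: self.payload, _state: PhantomData }
    }
}

// Only sealed envelopes can be sent; this is enforced at compile time.
fn send(envelope: &Envelope<Sealed>) {
    println!("sending: {}", envelope.payload);
}

fn main() {
    let sealed = Envelope::new("ping").seal();
    send(&sealed);
    // send(&Envelope::new("ping")); // would not compile: expected Envelope<Sealed>
}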

View file

@@ -0,0 +1,219 @@
// OxideTalis Messaging Protocol homeserver implementation
// Copyright (C) 2024 OxideTalis Developers <otmp@4rs.nl>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
use std::{collections::HashMap, sync::Arc, time::Duration};

use chrono::Utc;
use errors::{WsError, WsResult};
use futures::{channel::mpsc, FutureExt, StreamExt, TryStreamExt};
use once_cell::sync::Lazy;
use oxidetalis_core::{cipher::K256Secret, types::PublicKey};
use salvo::{
    handler,
    http::StatusError,
    websocket::{Message, WebSocket, WebSocketUpgrade},
    Depot,
    Request,
    Response,
    Router,
};
use tokio::{sync::RwLock, task::spawn as tokio_spawn, time::sleep as tokio_sleep};

mod errors;
mod events;

pub use events::*;
use uuid::Uuid;

use crate::{
    extensions::{DepotExt, OnlineUsersExt},
    middlewares,
    nonce::NonceCache,
    utils,
};

/// Online users type
pub type OnlineUsers = RwLock<HashMap<Uuid, SocketUserData>>;

/// List of online users, users that are connected to the server
// FIXME: Use `std::sync::LazyLock` after it becomes stable in `1.80.0`
static ONLINE_USERS: Lazy<OnlineUsers> = Lazy::new(OnlineUsers::default);

/// A user connected to the server
pub struct SocketUserData {
    /// Sender to send messages to the user
    pub sender: mpsc::UnboundedSender<salvo::Result<Message>>,
    /// User public key
    pub public_key: PublicKey,
    /// Time that the user pinged at
    pub pinged_at: chrono::DateTime<Utc>,
    /// Time that the user ponged at
    pub ponged_at: chrono::DateTime<Utc>,
    /// User shared secret
    pub shared_secret: [u8; 32],
}

impl SocketUserData {
    /// Creates new [`SocketUserData`]
    pub fn new(
        public_key: PublicKey,
        shared_secret: [u8; 32],
        sender: mpsc::UnboundedSender<salvo::Result<Message>>,
    ) -> Self {
        let now = Utc::now();
        Self {
            sender,
            public_key,
            shared_secret,
            pinged_at: now,
            ponged_at: now,
        }
    }
}

/// WebSocket handler, that handles the user connection
#[handler]
pub async fn user_connected(
    req: &mut Request,
    res: &mut Response,
    depot: &Depot,
) -> Result<(), StatusError> {
    let nonce_cache = depot.nonce_cache();
    let public_key =
awiteb marked this conversation as resolved
Review

something I noticed, this nonce_limit is just passed everywhere without a need.

in NonceCacheExt::add_nonce, it takes limit, even though both of these items come from the depot and are just passed around in every function, handle_socket -> handle_ws_msg -> verify_signature -> ...... -> add_nonce

A better approach is to make NonceCache a specific struct (not type) that holds both the hashmap and the limit

        utils::extract_public_key(req).expect("The public key was checked in the middleware");
    // FIXME: The config should hold `K256Secret` not `PrivateKey`
    let shared_secret =
        K256Secret::from_privkey(&depot.config().server.private_key).shared_secret(&public_key);

    WebSocketUpgrade::new()
        .upgrade(req, res, move |ws| {
            handle_socket(ws, nonce_cache, public_key, shared_secret)
        })
        .await
}

/// Handle the websocket connection
async fn handle_socket(
    ws: WebSocket,
    nonce_cache: Arc<NonceCache>,
    user_public_key: PublicKey,
    user_shared_secret: [u8; 32],
) {
    let (user_ws_sender, mut user_ws_receiver) = ws.split();

    let (sender, receiver) = mpsc::unbounded();
    let receiver = receiver.into_stream();
    let fut = receiver.forward(user_ws_sender).map(|result| {
        if let Err(err) = result {
awiteb marked this conversation as resolved
Review

sender is already Clone and is implemented with Arc

            log::error!("websocket send error: {err}");
        }
    });
    tokio_spawn(fut);

    let conn_id = Uuid::new_v4();
    let user = SocketUserData::new(user_public_key, user_shared_secret, sender.clone());
    ONLINE_USERS.add_user(&conn_id, user).await;
    log::info!("New user connected: ConnId(={conn_id}) PublicKey(={user_public_key})");

    let fut = async move {
        while let Some(Ok(msg)) = user_ws_receiver.next().await {
            match handle_ws_msg(msg, &nonce_cache, &user_shared_secret).await {
                Ok(event) => {
                    if let Some(server_event) = handle_events(event, &conn_id).await {
                        if let Err(err) = sender.unbounded_send(Ok(Message::from(
                            &server_event.sign(&user_shared_secret),
                        ))) {
                            log::error!("Websocket Error: {err}");
                            break;
                        }
                    };
                }
                Err(err) => {
                    if let Err(err) = sender.unbounded_send(Ok(Message::from(
                        &ServerEvent::from(err).sign(&user_shared_secret),
                    ))) {
                        log::error!("Websocket Error: {err}");
                        break;
                    };
                }
            };
        }
        user_disconnected(&conn_id, &user_public_key).await;
    };
    tokio_spawn(fut);
}
/// Handle websocket msg
async fn handle_ws_msg(
    msg: Message,
    nonce_cache: &NonceCache,
    shared_secret: &[u8; 32],
) -> WsResult<ClientEvent> {
    let Ok(text) = msg.to_str() else {
        return Err(WsError::NotTextMessage);
    };
    let event = serde_json::from_str::<ClientEvent>(text).map_err(|err| {
        if err.is_data() {
            WsError::UnknownClientEvent
        } else {
            WsError::InvalidJsonData
        }
    })?;
    if !event.verify_signature(shared_secret, nonce_cache).await {
        return Err(WsError::InvalidSignature);
    }
    Ok(event)
}

/// Handle user events, and return the server event if needed
async fn handle_events(event: ClientEvent, conn_id: &Uuid) -> Option<ServerEvent<Unsigned>> {
    match &event.event {
        ClientEventType::Ping { .. } => Some(ServerEvent::pong()),
        ClientEventType::Pong { .. } => {
            ONLINE_USERS.update_pong(conn_id).await;
            None
        }
    }
}

/// Handle user disconnected
async fn user_disconnected(conn_id: &Uuid, public_key: &PublicKey) {
    ONLINE_USERS.remove_user(conn_id).await;
    log::debug!("User disconnect: ConnId(={conn_id}) PublicKey(={public_key})");
}

pub fn route() -> Router {
    let users_pinger = async {
        /// Seconds to wait for pongs, before disconnecting the user
        const WAIT_FOR_PONGS_SECS: u32 = 10;
        /// Seconds to sleep between pings (10 minutes)
        const SLEEP_SECS: u32 = 60 * 10;

        loop {
            log::debug!("Start pinging online users");
            ONLINE_USERS.ping_all().await;
            tokio_sleep(Duration::from_secs(u64::from(WAIT_FOR_PONGS_SECS))).await;
            ONLINE_USERS.disconnect_inactive_users().await;
            log::debug!("Done pinging online users and disconnected inactive ones");
            tokio_sleep(Duration::from_secs(u64::from(SLEEP_SECS))).await;
awiteb marked this conversation as resolved
Review

I think both of these would be better as config as well, more configuration options for admin, 10 mins is good enough of course and probably won't change by most ppl

Review

more configuration options for admin

I wanted to implement it today, but I felt that this is not something that should be determined by the server administrator? I've never seen a server provide it.

I will not implement it. I don't think that the server administrator must specify the period during which the server sends pings.

        }
    };
    tokio_spawn(users_pinger);

    Router::new()
        .push(Router::with_path("chat").get(user_connected))
        .hoop(middlewares::signature_check)
        .hoop(middlewares::public_key_check)
}