feat: Initialize server websocket #8

Merged
awiteb merged 4 commits from awiteb/init-websocket into master 2024-07-06 13:35:10 +02:00 AGit
8 changed files with 96 additions and 86 deletions
Showing only changes of commit d442e73ed7 - Show all commits

View file

@ -24,9 +24,8 @@ use sea_orm::DatabaseConnection;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
routes::DEPOT_NONCE_CACHE_SIZE, nonce::NonceCache,
websocket::{OnlineUsers, ServerEvent, SocketUserData}, websocket::{OnlineUsers, ServerEvent, SocketUserData},
NonceCache,
}; };
/// Extension trait for the Depot. /// Extension trait for the Depot.
@ -37,15 +36,6 @@ pub trait DepotExt {
fn config(&self) -> &Config; fn config(&self) -> &Config;
/// Retutns the nonce cache /// Retutns the nonce cache
fn nonce_cache(&self) -> Arc<NonceCache>; fn nonce_cache(&self) -> Arc<NonceCache>;
/// Returns the size of the nonce cache
fn nonce_cache_size(&self) -> &usize;
}
/// Extension trait for the nonce cache.
pub trait NonceCacheExt {
/// Add a nonce to the cache, returns `true` if the nonce is added, `false`
/// if the nonce is already exist in the cache.
fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool;
} }
/// Extension trait for online websocket users /// Extension trait for online websocket users
@ -82,38 +72,6 @@ impl DepotExt for Depot {
.expect("Nonce cache not found"), .expect("Nonce cache not found"),
) )
} }
fn nonce_cache_size(&self) -> &usize {
let s: &Arc<usize> = self
.get(DEPOT_NONCE_CACHE_SIZE)
.expect("Nonce cache size not found");
s.as_ref()
}
}
impl NonceCacheExt for &NonceCache {
fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool {
let mut cache = self.lock().expect("Nonce cache lock poisoned, aborting...");
let now = Utc::now().timestamp();
cache.retain(|_, time| (now - *time) < 30);
if &cache.len() >= limit {
log::warn!("Nonce cache limit reached, clearing 10% of the cache");
let num_to_remove = limit / 10;
let keys: Vec<[u8; 16]> = cache.keys().copied().collect();
for key in keys.iter().take(num_to_remove) {
cache.remove(key);
}
}
// We can use insert directly, but it's will update the value if the key is
// already exist so we need to check if the key is already exist or not
if cache.contains_key(nonce) {
return false;
}
cache.insert(*nonce, now);
true
}
} }
impl OnlineUsersExt for OnlineUsers { impl OnlineUsersExt for OnlineUsers {

View file

@ -17,7 +17,7 @@
#![doc = include_str!("../../../README.md")] #![doc = include_str!("../../../README.md")]
#![warn(missing_docs, unsafe_code)] #![warn(missing_docs, unsafe_code)]
use std::{collections::HashMap, process::ExitCode, sync::Mutex}; use std::process::ExitCode;
use oxidetalis_config::{CliArgs, Parser}; use oxidetalis_config::{CliArgs, Parser};
use oxidetalis_migrations::MigratorTrait; use oxidetalis_migrations::MigratorTrait;
@ -27,14 +27,12 @@ mod database;
mod errors; mod errors;
mod extensions; mod extensions;
mod middlewares; mod middlewares;
mod nonce;
mod routes; mod routes;
mod schemas; mod schemas;
mod utils; mod utils;
mod websocket; mod websocket;
/// Nonce cache type, used to store nonces for a certain amount of time
pub type NonceCache = Mutex<HashMap<[u8; 16], i64>>;
async fn try_main() -> errors::Result<()> { async fn try_main() -> errors::Result<()> {
pretty_env_logger::init_timed(); pretty_env_logger::init_timed();
@ -50,6 +48,7 @@ async fn try_main() -> errors::Result<()> {
let local_addr = format!("{}:{}", config.server.host, config.server.port); let local_addr = format!("{}:{}", config.server.host, config.server.port);
let acceptor = TcpListener::new(&local_addr).bind().await; let acceptor = TcpListener::new(&local_addr).bind().await;
log::info!("Server listening on http://{local_addr}"); log::info!("Server listening on http://{local_addr}");
log::info!("Chat websocket on ws://{local_addr}/ws/chat");
if config.openapi.enable { if config.openapi.enable {
log::info!( log::info!(
"The openapi schema is available at http://{local_addr}{}", "The openapi schema is available at http://{local_addr}{}",

View file

@ -70,7 +70,7 @@ pub async fn signature_check(
} }
}; };
if !utils::is_valid_nonce(&signature, &depot.nonce_cache(), depot.nonce_cache_size()) if !utils::is_valid_nonce(&signature, &depot.nonce_cache()).await
|| !utils::is_valid_signature( || !utils::is_valid_signature(
&sender_public_key, &sender_public_key,
&depot.config().server.private_key, &depot.config().server.private_key,

View file

@ -0,0 +1,73 @@
// OxideTalis Messaging Protocol homeserver implementation
// Copyright (C) 2024 OxideTalis Developers <otmp@4rs.nl>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
//! Nonce cache implementation
use std::{collections::HashMap, mem};
use chrono::Utc;
use oxidetalis_core::types::Size;
use tokio::sync::Mutex as TokioMutex;
/// Size of each entry in the nonce cache
pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::<i16>();
/// Size of the hashmap itself without the entrys (48 bytes)
pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::<HashMap<u8, u8>>();
/// Nonce cache struct, used to store nonces for a short period of time
/// to prevent replay attacks, each nonce has a 30 seconds lifetime.
///
/// The cache will remove first 10% nonces if the cache limit is reached.
pub struct NonceCache {
/// The nonce cache hashmap, the key is the nonce and the value is the time
awiteb marked this conversation as resolved
Review

This is not an issue, but I'm curios:

  • why not put the mutex here inside, instead of passing Mutex<NonceCache> around, we can use NonceCache and handle locking inside.
  • why not use tokio mutex in this, we would need to change many functions to be async of course, not sure which is better.
This is not an issue, but I'm curios: - why not put the mutex here inside, instead of passing `Mutex<NonceCache>` around, we can use `NonceCache` and handle locking inside. - why not use tokio mutex in this, we would need to change many functions to be `async` of course, not sure which is better.
Review

why not put the mutex here inside

Actually a good idea, I'll.

not sure which is better.

I'll see

> why not put the mutex here inside Actually a good idea, I'll. > not sure which is better. I'll see
cache: TokioMutex<HashMap<[u8; 16], i64>>,
}
impl NonceCache {
/// Creates new [`NonceCache`] instance, with the given cache limit
pub fn new(cache_limit: &Size) -> Self {
Self {
cache: TokioMutex::new(HashMap::with_capacity(
(cache_limit.as_bytes() - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
)),
}
}
/// Add a nonce to the cache, returns `true` if the nonce is added, `false`
/// if the nonce is already exist in the cache.
pub async fn add_nonce(&self, nonce: &[u8; 16]) -> bool {
let mut cache = self.cache.lock().await;
let now = Utc::now().timestamp();
cache.retain(|_, time| (now - *time) < 30);
if cache.len() == cache.capacity() {
log::warn!("Nonce cache limit reached, clearing 10% of the cache");
let num_to_remove = cache.capacity() / 10;
let keys: Vec<[u8; 16]> = cache.keys().copied().collect();
for key in keys.iter().take(num_to_remove) {
cache.remove(key);
}
}
// We can use insert directly, but it's will update the value if the key is
// already exist so we need to check if the key is already exist or not
if cache.contains_key(nonce) {
return false;
}
cache.insert(*nonce, now);
true
}
}

View file

@ -14,9 +14,8 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>. // along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
use std::collections::HashMap; use std::env;
use std::sync::{Arc, Mutex}; use std::sync::Arc;
use std::{env, mem};
use oxidetalis_config::Config; use oxidetalis_config::Config;
use salvo::http::ResBody; use salvo::http::ResBody;
@ -24,18 +23,12 @@ use salvo::oapi::{Info, License};
use salvo::rate_limiter::{BasicQuota, FixedGuard, MokaStore, RateLimiter, RemoteIpIssuer}; use salvo::rate_limiter::{BasicQuota, FixedGuard, MokaStore, RateLimiter, RemoteIpIssuer};
use salvo::{catcher::Catcher, logging::Logger, prelude::*}; use salvo::{catcher::Catcher, logging::Logger, prelude::*};
use crate::nonce::NonceCache;
use crate::schemas::MessageSchema; use crate::schemas::MessageSchema;
use crate::{middlewares, websocket, NonceCache}; use crate::{middlewares, websocket};
mod user; mod user;
/// Size of each entry in the nonce cache
pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::<i16>();
/// Size of the hashmap itself without the entrys (48 bytes)
pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::<HashMap<u8, u8>>();
/// Name of the nonce cache size in the depot
pub(crate) const DEPOT_NONCE_CACHE_SIZE: &str = "NONCE_CACHE_SIZE";
pub fn write_json_body(res: &mut Response, json_body: impl serde::Serialize) { pub fn write_json_body(res: &mut Response, json_body: impl serde::Serialize) {
res.write_body(serde_json::to_string(&json_body).expect("Json serialization can't be fail")) res.write_body(serde_json::to_string(&json_body).expect("Json serialization can't be fail"))
.ok(); .ok();
@ -69,7 +62,7 @@ async fn handle_server_errors(res: &mut Response, ctrl: &mut FlowCtrl) {
err.brief.trim_end_matches('.'), err.brief.trim_end_matches('.'),
err.cause err.cause
.as_deref() .as_deref()
.map_or_else(|| "".to_owned(), ToString::to_string) .map_or_else(String::new, ToString::to_string)
.trim_end_matches('.') .trim_end_matches('.')
.split(':') .split(':')
.last() .last()
@ -129,13 +122,9 @@ fn route_openapi(config: &Config, router: Router) -> Router {
} }
pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service { pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service {
let nonce_cache_size = config.server.nonce_cache_size.as_bytes(); let nonce_cache: NonceCache = NonceCache::new(&config.server.nonce_cache_size);
let nonce_cache: NonceCache = Mutex::new(HashMap::with_capacity(
(nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
));
log::info!( log::info!(
"Nonce cache created with a capacity of {} ({})", "Nonce cache created with a capacity of {}",
(nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
config.server.nonce_cache_size config.server.nonce_cache_size
); );
@ -146,7 +135,6 @@ pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service {
.hoop(Logger::new()) .hoop(Logger::new())
.hoop( .hoop(
affix::inject(Arc::new(conn)) affix::inject(Arc::new(conn))
.insert(DEPOT_NONCE_CACHE_SIZE, Arc::new(nonce_cache_size))
.inject(Arc::new(config.clone())) .inject(Arc::new(config.clone()))
.inject(Arc::new(nonce_cache)), .inject(Arc::new(nonce_cache)),
); );

View file

@ -27,7 +27,7 @@ use oxidetalis_core::{
}; };
use salvo::Request; use salvo::Request;
use crate::{extensions::NonceCacheExt, NonceCache}; use crate::nonce::NonceCache;
/// Returns the postgres database url /// Returns the postgres database url
#[logcall] #[logcall]
@ -43,16 +43,12 @@ pub(crate) fn postgres_url(db_config: &Postgres) -> String {
} }
/// Returns true if the given nonce a valid nonce. /// Returns true if the given nonce a valid nonce.
pub(crate) fn is_valid_nonce( pub(crate) async fn is_valid_nonce(signature: &Signature, nonce_cache: &NonceCache) -> bool {
signature: &Signature,
nonce_cache: &NonceCache,
nonce_cache_limit: &usize,
) -> bool {
let new_timestamp = Utc::now() let new_timestamp = Utc::now()
.timestamp() .timestamp()
.checked_sub(u64::from_be_bytes(*signature.timestamp()) as i64) .checked_sub(u64::from_be_bytes(*signature.timestamp()) as i64)
.is_some_and(|n| n <= 20); .is_some_and(|n| n <= 20);
let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce(), nonce_cache_limit); let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce()).await;
new_timestamp && unused_nonce new_timestamp && unused_nonce
} }

View file

@ -19,7 +19,7 @@
use oxidetalis_core::types::Signature; use oxidetalis_core::types::Signature;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{utils, NonceCache}; use crate::{nonce::NonceCache, utils};
/// Client websocket event /// Client websocket event
#[derive(Deserialize, Clone, Debug)] #[derive(Deserialize, Clone, Debug)]
@ -49,13 +49,12 @@ impl ClientEventType {
impl ClientEvent { impl ClientEvent {
/// Verify the signature of the event /// Verify the signature of the event
pub fn verify_signature( pub async fn verify_signature(
&self, &self,
shared_secret: &[u8; 32], shared_secret: &[u8; 32],
nonce_cache: &NonceCache, nonce_cache: &NonceCache,
nonce_limit: &usize,
) -> bool { ) -> bool {
utils::is_valid_nonce(&self.signature, nonce_cache, nonce_limit) utils::is_valid_nonce(&self.signature, nonce_cache).await
&& self.signature.verify(&self.event.data(), shared_secret) && self.signature.verify(&self.event.data(), shared_secret)
} }
} }

View file

@ -41,8 +41,8 @@ use uuid::Uuid;
use crate::{ use crate::{
extensions::{DepotExt, OnlineUsersExt}, extensions::{DepotExt, OnlineUsersExt},
middlewares, middlewares,
nonce::NonceCache,
utils, utils,
NonceCache,
}; };
/// Online users type /// Online users type
@ -92,7 +92,6 @@ pub async fn user_connected(
depot: &Depot, depot: &Depot,
) -> Result<(), StatusError> { ) -> Result<(), StatusError> {
let nonce_cache = depot.nonce_cache(); let nonce_cache = depot.nonce_cache();
let nonce_limit = *depot.nonce_cache_size();
let public_key = let public_key =
awiteb marked this conversation as resolved
Review

something I noticed, this nonce_limit is just passed everywhere without a need.

in NonceCacheExt::add_nonce, it takes limit, even though both of these items comes from the depot and just passed around in every function, handle_socket -> handle_ws_msg -> verify_signature -> ...... -> add_nonce

A better approach is to make NonceCache a specific struct (not type) that holds both the hashmap and the limit

something I noticed, this `nonce_limit` is just passed everywhere without a need. in `NonceCacheExt::add_nonce`, it takes `limit`, even though both of these items comes from the depot and just passed around in every function, `handle_socket -> handle_ws_msg -> verify_signature -> ...... -> add_nonce` A better approach is to make `NonceCache` a specific struct (not type) that holds both the hashmap and the limit
utils::extract_public_key(req).expect("The public key was checked in the middleware"); utils::extract_public_key(req).expect("The public key was checked in the middleware");
// FIXME: The config should hold `K256Secret` not `PrivateKey` // FIXME: The config should hold `K256Secret` not `PrivateKey`
@ -101,7 +100,7 @@ pub async fn user_connected(
WebSocketUpgrade::new() WebSocketUpgrade::new()
.upgrade(req, res, move |ws| { .upgrade(req, res, move |ws| {
handle_socket(ws, nonce_cache, nonce_limit, public_key, shared_secret) handle_socket(ws, nonce_cache, public_key, shared_secret)
}) })
.await .await
} }
@ -110,7 +109,6 @@ pub async fn user_connected(
async fn handle_socket( async fn handle_socket(
ws: WebSocket, ws: WebSocket,
nonce_cache: Arc<NonceCache>, nonce_cache: Arc<NonceCache>,
nonce_limit: usize,
user_public_key: PublicKey, user_public_key: PublicKey,
user_shared_secret: [u8; 32], user_shared_secret: [u8; 32],
) { ) {
@ -131,7 +129,7 @@ async fn handle_socket(
let fut = async move { let fut = async move {
while let Some(Ok(msg)) = user_ws_receiver.next().await { while let Some(Ok(msg)) = user_ws_receiver.next().await {
match handle_ws_msg(msg, &nonce_cache, &nonce_limit, &user_shared_secret) { match handle_ws_msg(msg, &nonce_cache, &user_shared_secret).await {
Ok(event) => { Ok(event) => {
if let Some(server_event) = handle_events(event, &conn_id).await { if let Some(server_event) = handle_events(event, &conn_id).await {
if let Err(err) = sender.unbounded_send(Ok(Message::from( if let Err(err) = sender.unbounded_send(Ok(Message::from(
@ -158,10 +156,9 @@ async fn handle_socket(
} }
/// Handle websocket msg /// Handle websocket msg
fn handle_ws_msg( async fn handle_ws_msg(
msg: Message, msg: Message,
nonce_cache: &NonceCache, nonce_cache: &NonceCache,
nonce_limit: &usize,
shared_secret: &[u8; 32], shared_secret: &[u8; 32],
) -> WsResult<ClientEvent> { ) -> WsResult<ClientEvent> {
let Ok(text) = msg.to_str() else { let Ok(text) = msg.to_str() else {
@ -174,7 +171,7 @@ fn handle_ws_msg(
WsError::InvalidJsonData WsError::InvalidJsonData
} }
})?; })?;
if !event.verify_signature(shared_secret, nonce_cache, nonce_limit) { if !event.verify_signature(shared_secret, nonce_cache).await {
return Err(WsError::InvalidSignature); return Err(WsError::InvalidSignature);
} }
Ok(event) Ok(event)