change: Remove the nonce cache limit

Don't pass the nonce cache limit everywhere, we allocate the hashmap capacity
with it, so we can use the hashmap capacity directly.

Also refactor the `NonceCache` type, make it better without extension trait.

Suggested-by: Amjad Alsharafi <me@amjad.alsharafi.dev>
Reviewed-by: Amjad Alsharafi <me@amjad.alsharafi.dev>
Signed-off-by: Awiteb <a@4rs.nl>
This commit is contained in:
Awiteb 2024-07-06 01:25:06 +03:00
parent c0d5efe0c3
commit d442e73ed7
Signed by: awiteb
GPG key ID: 3F6B55640AA6682F
8 changed files with 96 additions and 86 deletions

View file

@ -24,9 +24,8 @@ use sea_orm::DatabaseConnection;
use uuid::Uuid; use uuid::Uuid;
use crate::{ use crate::{
routes::DEPOT_NONCE_CACHE_SIZE, nonce::NonceCache,
websocket::{OnlineUsers, ServerEvent, SocketUserData}, websocket::{OnlineUsers, ServerEvent, SocketUserData},
NonceCache,
}; };
/// Extension trait for the Depot. /// Extension trait for the Depot.
@ -37,15 +36,6 @@ pub trait DepotExt {
fn config(&self) -> &Config; fn config(&self) -> &Config;
/// Retutns the nonce cache /// Retutns the nonce cache
fn nonce_cache(&self) -> Arc<NonceCache>; fn nonce_cache(&self) -> Arc<NonceCache>;
/// Returns the size of the nonce cache
fn nonce_cache_size(&self) -> &usize;
}
/// Extension trait for the nonce cache.
pub trait NonceCacheExt {
/// Add a nonce to the cache, returns `true` if the nonce is added, `false`
/// if the nonce is already exist in the cache.
fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool;
} }
/// Extension trait for online websocket users /// Extension trait for online websocket users
@ -82,38 +72,6 @@ impl DepotExt for Depot {
.expect("Nonce cache not found"), .expect("Nonce cache not found"),
) )
} }
fn nonce_cache_size(&self) -> &usize {
let s: &Arc<usize> = self
.get(DEPOT_NONCE_CACHE_SIZE)
.expect("Nonce cache size not found");
s.as_ref()
}
}
impl NonceCacheExt for &NonceCache {
fn add_nonce(&self, nonce: &[u8; 16], limit: &usize) -> bool {
let mut cache = self.lock().expect("Nonce cache lock poisoned, aborting...");
let now = Utc::now().timestamp();
cache.retain(|_, time| (now - *time) < 30);
if &cache.len() >= limit {
log::warn!("Nonce cache limit reached, clearing 10% of the cache");
let num_to_remove = limit / 10;
let keys: Vec<[u8; 16]> = cache.keys().copied().collect();
for key in keys.iter().take(num_to_remove) {
cache.remove(key);
}
}
// We can use insert directly, but it's will update the value if the key is
// already exist so we need to check if the key is already exist or not
if cache.contains_key(nonce) {
return false;
}
cache.insert(*nonce, now);
true
}
} }
impl OnlineUsersExt for OnlineUsers { impl OnlineUsersExt for OnlineUsers {

View file

@ -17,7 +17,7 @@
#![doc = include_str!("../../../README.md")] #![doc = include_str!("../../../README.md")]
#![warn(missing_docs, unsafe_code)] #![warn(missing_docs, unsafe_code)]
use std::{collections::HashMap, process::ExitCode, sync::Mutex}; use std::process::ExitCode;
use oxidetalis_config::{CliArgs, Parser}; use oxidetalis_config::{CliArgs, Parser};
use oxidetalis_migrations::MigratorTrait; use oxidetalis_migrations::MigratorTrait;
@ -27,14 +27,12 @@ mod database;
mod errors; mod errors;
mod extensions; mod extensions;
mod middlewares; mod middlewares;
mod nonce;
mod routes; mod routes;
mod schemas; mod schemas;
mod utils; mod utils;
mod websocket; mod websocket;
/// Nonce cache type, used to store nonces for a certain amount of time
pub type NonceCache = Mutex<HashMap<[u8; 16], i64>>;
async fn try_main() -> errors::Result<()> { async fn try_main() -> errors::Result<()> {
pretty_env_logger::init_timed(); pretty_env_logger::init_timed();
@ -50,6 +48,7 @@ async fn try_main() -> errors::Result<()> {
let local_addr = format!("{}:{}", config.server.host, config.server.port); let local_addr = format!("{}:{}", config.server.host, config.server.port);
let acceptor = TcpListener::new(&local_addr).bind().await; let acceptor = TcpListener::new(&local_addr).bind().await;
log::info!("Server listening on http://{local_addr}"); log::info!("Server listening on http://{local_addr}");
log::info!("Chat websocket on ws://{local_addr}/ws/chat");
if config.openapi.enable { if config.openapi.enable {
log::info!( log::info!(
"The openapi schema is available at http://{local_addr}{}", "The openapi schema is available at http://{local_addr}{}",

View file

@ -70,7 +70,7 @@ pub async fn signature_check(
} }
}; };
if !utils::is_valid_nonce(&signature, &depot.nonce_cache(), depot.nonce_cache_size()) if !utils::is_valid_nonce(&signature, &depot.nonce_cache()).await
|| !utils::is_valid_signature( || !utils::is_valid_signature(
&sender_public_key, &sender_public_key,
&depot.config().server.private_key, &depot.config().server.private_key,

View file

@ -0,0 +1,73 @@
// OxideTalis Messaging Protocol homeserver implementation
// Copyright (C) 2024 OxideTalis Developers <otmp@4rs.nl>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
//! Nonce cache implementation
use std::{collections::HashMap, mem};
use chrono::Utc;
use oxidetalis_core::types::Size;
use tokio::sync::Mutex as TokioMutex;
/// Size of each entry in the nonce cache
pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::<i16>();
/// Size of the hashmap itself without the entrys (48 bytes)
pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::<HashMap<u8, u8>>();
/// Nonce cache struct, used to store nonces for a short period of time
/// to prevent replay attacks, each nonce has a 30 seconds lifetime.
///
/// The cache will remove first 10% nonces if the cache limit is reached.
pub struct NonceCache {
/// The nonce cache hashmap, the key is the nonce and the value is the time
cache: TokioMutex<HashMap<[u8; 16], i64>>,
}
impl NonceCache {
/// Creates new [`NonceCache`] instance, with the given cache limit
pub fn new(cache_limit: &Size) -> Self {
Self {
cache: TokioMutex::new(HashMap::with_capacity(
(cache_limit.as_bytes() - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
)),
}
}
/// Add a nonce to the cache, returns `true` if the nonce is added, `false`
/// if the nonce is already exist in the cache.
pub async fn add_nonce(&self, nonce: &[u8; 16]) -> bool {
let mut cache = self.cache.lock().await;
let now = Utc::now().timestamp();
cache.retain(|_, time| (now - *time) < 30);
if cache.len() == cache.capacity() {
log::warn!("Nonce cache limit reached, clearing 10% of the cache");
let num_to_remove = cache.capacity() / 10;
let keys: Vec<[u8; 16]> = cache.keys().copied().collect();
for key in keys.iter().take(num_to_remove) {
cache.remove(key);
}
}
// We can use insert directly, but it's will update the value if the key is
// already exist so we need to check if the key is already exist or not
if cache.contains_key(nonce) {
return false;
}
cache.insert(*nonce, now);
true
}
}

View file

@ -14,9 +14,8 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>. // along with this program. If not, see <https://gnu.org/licenses/agpl-3.0>.
use std::collections::HashMap; use std::env;
use std::sync::{Arc, Mutex}; use std::sync::Arc;
use std::{env, mem};
use oxidetalis_config::Config; use oxidetalis_config::Config;
use salvo::http::ResBody; use salvo::http::ResBody;
@ -24,18 +23,12 @@ use salvo::oapi::{Info, License};
use salvo::rate_limiter::{BasicQuota, FixedGuard, MokaStore, RateLimiter, RemoteIpIssuer}; use salvo::rate_limiter::{BasicQuota, FixedGuard, MokaStore, RateLimiter, RemoteIpIssuer};
use salvo::{catcher::Catcher, logging::Logger, prelude::*}; use salvo::{catcher::Catcher, logging::Logger, prelude::*};
use crate::nonce::NonceCache;
use crate::schemas::MessageSchema; use crate::schemas::MessageSchema;
use crate::{middlewares, websocket, NonceCache}; use crate::{middlewares, websocket};
mod user; mod user;
/// Size of each entry in the nonce cache
pub(crate) const NONCE_ENTRY_SIZE: usize = mem::size_of::<[u8; 16]>() + mem::size_of::<i16>();
/// Size of the hashmap itself without the entrys (48 bytes)
pub(crate) const HASH_MAP_SIZE: usize = mem::size_of::<HashMap<u8, u8>>();
/// Name of the nonce cache size in the depot
pub(crate) const DEPOT_NONCE_CACHE_SIZE: &str = "NONCE_CACHE_SIZE";
pub fn write_json_body(res: &mut Response, json_body: impl serde::Serialize) { pub fn write_json_body(res: &mut Response, json_body: impl serde::Serialize) {
res.write_body(serde_json::to_string(&json_body).expect("Json serialization can't be fail")) res.write_body(serde_json::to_string(&json_body).expect("Json serialization can't be fail"))
.ok(); .ok();
@ -69,7 +62,7 @@ async fn handle_server_errors(res: &mut Response, ctrl: &mut FlowCtrl) {
err.brief.trim_end_matches('.'), err.brief.trim_end_matches('.'),
err.cause err.cause
.as_deref() .as_deref()
.map_or_else(|| "".to_owned(), ToString::to_string) .map_or_else(String::new, ToString::to_string)
.trim_end_matches('.') .trim_end_matches('.')
.split(':') .split(':')
.last() .last()
@ -129,13 +122,9 @@ fn route_openapi(config: &Config, router: Router) -> Router {
} }
pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service { pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service {
let nonce_cache_size = config.server.nonce_cache_size.as_bytes(); let nonce_cache: NonceCache = NonceCache::new(&config.server.nonce_cache_size);
let nonce_cache: NonceCache = Mutex::new(HashMap::with_capacity(
(nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
));
log::info!( log::info!(
"Nonce cache created with a capacity of {} ({})", "Nonce cache created with a capacity of {}",
(nonce_cache_size - HASH_MAP_SIZE) / NONCE_ENTRY_SIZE,
config.server.nonce_cache_size config.server.nonce_cache_size
); );
@ -146,7 +135,6 @@ pub fn service(conn: sea_orm::DatabaseConnection, config: &Config) -> Service {
.hoop(Logger::new()) .hoop(Logger::new())
.hoop( .hoop(
affix::inject(Arc::new(conn)) affix::inject(Arc::new(conn))
.insert(DEPOT_NONCE_CACHE_SIZE, Arc::new(nonce_cache_size))
.inject(Arc::new(config.clone())) .inject(Arc::new(config.clone()))
.inject(Arc::new(nonce_cache)), .inject(Arc::new(nonce_cache)),
); );

View file

@ -27,7 +27,7 @@ use oxidetalis_core::{
}; };
use salvo::Request; use salvo::Request;
use crate::{extensions::NonceCacheExt, NonceCache}; use crate::nonce::NonceCache;
/// Returns the postgres database url /// Returns the postgres database url
#[logcall] #[logcall]
@ -43,16 +43,12 @@ pub(crate) fn postgres_url(db_config: &Postgres) -> String {
} }
/// Returns true if the given nonce a valid nonce. /// Returns true if the given nonce a valid nonce.
pub(crate) fn is_valid_nonce( pub(crate) async fn is_valid_nonce(signature: &Signature, nonce_cache: &NonceCache) -> bool {
signature: &Signature,
nonce_cache: &NonceCache,
nonce_cache_limit: &usize,
) -> bool {
let new_timestamp = Utc::now() let new_timestamp = Utc::now()
.timestamp() .timestamp()
.checked_sub(u64::from_be_bytes(*signature.timestamp()) as i64) .checked_sub(u64::from_be_bytes(*signature.timestamp()) as i64)
.is_some_and(|n| n <= 20); .is_some_and(|n| n <= 20);
let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce(), nonce_cache_limit); let unused_nonce = new_timestamp && nonce_cache.add_nonce(signature.nonce()).await;
new_timestamp && unused_nonce new_timestamp && unused_nonce
} }

View file

@ -19,7 +19,7 @@
use oxidetalis_core::types::Signature; use oxidetalis_core::types::Signature;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{utils, NonceCache}; use crate::{nonce::NonceCache, utils};
/// Client websocket event /// Client websocket event
#[derive(Deserialize, Clone, Debug)] #[derive(Deserialize, Clone, Debug)]
@ -49,13 +49,12 @@ impl ClientEventType {
impl ClientEvent { impl ClientEvent {
/// Verify the signature of the event /// Verify the signature of the event
pub fn verify_signature( pub async fn verify_signature(
&self, &self,
shared_secret: &[u8; 32], shared_secret: &[u8; 32],
nonce_cache: &NonceCache, nonce_cache: &NonceCache,
nonce_limit: &usize,
) -> bool { ) -> bool {
utils::is_valid_nonce(&self.signature, nonce_cache, nonce_limit) utils::is_valid_nonce(&self.signature, nonce_cache).await
&& self.signature.verify(&self.event.data(), shared_secret) && self.signature.verify(&self.event.data(), shared_secret)
} }
} }

View file

@ -41,8 +41,8 @@ use uuid::Uuid;
use crate::{ use crate::{
extensions::{DepotExt, OnlineUsersExt}, extensions::{DepotExt, OnlineUsersExt},
middlewares, middlewares,
nonce::NonceCache,
utils, utils,
NonceCache,
}; };
/// Online users type /// Online users type
@ -92,7 +92,6 @@ pub async fn user_connected(
depot: &Depot, depot: &Depot,
) -> Result<(), StatusError> { ) -> Result<(), StatusError> {
let nonce_cache = depot.nonce_cache(); let nonce_cache = depot.nonce_cache();
let nonce_limit = *depot.nonce_cache_size();
let public_key = let public_key =
utils::extract_public_key(req).expect("The public key was checked in the middleware"); utils::extract_public_key(req).expect("The public key was checked in the middleware");
// FIXME: The config should hold `K256Secret` not `PrivateKey` // FIXME: The config should hold `K256Secret` not `PrivateKey`
@ -101,7 +100,7 @@ pub async fn user_connected(
WebSocketUpgrade::new() WebSocketUpgrade::new()
.upgrade(req, res, move |ws| { .upgrade(req, res, move |ws| {
handle_socket(ws, nonce_cache, nonce_limit, public_key, shared_secret) handle_socket(ws, nonce_cache, public_key, shared_secret)
}) })
.await .await
} }
@ -110,7 +109,6 @@ pub async fn user_connected(
async fn handle_socket( async fn handle_socket(
ws: WebSocket, ws: WebSocket,
nonce_cache: Arc<NonceCache>, nonce_cache: Arc<NonceCache>,
nonce_limit: usize,
user_public_key: PublicKey, user_public_key: PublicKey,
user_shared_secret: [u8; 32], user_shared_secret: [u8; 32],
) { ) {
@ -131,7 +129,7 @@ async fn handle_socket(
let fut = async move { let fut = async move {
while let Some(Ok(msg)) = user_ws_receiver.next().await { while let Some(Ok(msg)) = user_ws_receiver.next().await {
match handle_ws_msg(msg, &nonce_cache, &nonce_limit, &user_shared_secret) { match handle_ws_msg(msg, &nonce_cache, &user_shared_secret).await {
Ok(event) => { Ok(event) => {
if let Some(server_event) = handle_events(event, &conn_id).await { if let Some(server_event) = handle_events(event, &conn_id).await {
if let Err(err) = sender.unbounded_send(Ok(Message::from( if let Err(err) = sender.unbounded_send(Ok(Message::from(
@ -158,10 +156,9 @@ async fn handle_socket(
} }
/// Handle websocket msg /// Handle websocket msg
fn handle_ws_msg( async fn handle_ws_msg(
msg: Message, msg: Message,
nonce_cache: &NonceCache, nonce_cache: &NonceCache,
nonce_limit: &usize,
shared_secret: &[u8; 32], shared_secret: &[u8; 32],
) -> WsResult<ClientEvent> { ) -> WsResult<ClientEvent> {
let Ok(text) = msg.to_str() else { let Ok(text) = msg.to_str() else {
@ -174,7 +171,7 @@ fn handle_ws_msg(
WsError::InvalidJsonData WsError::InvalidJsonData
} }
})?; })?;
if !event.verify_signature(shared_secret, nonce_cache, nonce_limit) { if !event.verify_signature(shared_secret, nonce_cache).await {
return Err(WsError::InvalidSignature); return Err(WsError::InvalidSignature);
} }
Ok(event) Ok(event)