README + Clean up
All checks were successful
Deploy API / build-and-deploy (push) Successful in 23s

This commit is contained in:
2026-05-01 12:12:55 +02:00
parent e8b6a392a1
commit 4043f9d032
14 changed files with 165 additions and 305 deletions

106
README.md
View File

@@ -1 +1,107 @@
# hotel-api-rs
##Description
Ce projet est un proof of concept servant de terrain d'expérimentation pour un outil interne destiné à des hotels.
Objectif:
-Petite échelle: ~50 hôtels maximum, ~5 utilisateurs par hôtel
-Facilité de maintenance, y compris pour des utilisateurs peu à l'aise avec Linux
Ce projet m'a persmis de découvrir et mettre en pratique :
-Rust (API backend)
-Authentification JWT
-Hash de mots de passe(Argon2)
-WebSockets (temps réel)
-Docker
-GitActions
-SQlite
##Fonctionalités
-CRUD (rooms, inventory, chat)
-Authentification :
-Refresh token
-Session token
-Notifications en temps réel(WebSocket)
-Chat en temps réel
##Démo
Une interface de démonstration est disponible ici:
https://mallardromain.com/hotel-demo/
##Installation
###Local
```
cargo run
```
-Les bases SQLite sont présentes par défaut dans `/db`
-Port par défaut : 7080(configuré dans `main.rs`)
###Docker
```
docker run -d \
-e JWT_SECRET=secret_JWT \
-v "/votre/chemin/:/app/db" \
```
Notes:
-Il est nécessaire de copier les bases de donner par défaut dans le volume
-Le scret doit obligatoirement être préciser a la création du container
-Valeur de JWT_SECRET pour les utilisateur présent sur les bases de donnée par défault:
`your_jwt_secret_key`
##Architecture
###Routing
-Point d'entrée: `./src/routes/mod.rs` (utilié dans le `./main.rs`)
-Organisation par domaine:
-chaque module posséde son propre `routes.rs`
###Modules principaux
-`.src/rooms`,`.src/inventory` et `.src/chat` sont principalement de la logique CRUD
###Utils
-`src/utils/db_pools.rs`
Getsion des connexions aux bases SQLite de chaque hôtels
-`.src/utils/websocket.rs`
Implémentation des WebSockets:
-Notification temps réel
-chat (émission/reception)
-`.src/utils/auth.rs`
-Hash/vérification des mots de passe(Argon2)
-Génération et validation des JWT
-Pré-traitement des tokens
##Authentification
Flow:
1. Obtenir un "refresh token"
`GET /auth/get_refresh`
2. Obtenir un session token
`POST /auth/login_refresh_token`
3. Ouvrir une connexion WebSocket
`/auth/ws/YourToken`
##Ce que j'ai appris
-Utilisation du framework "Axum"
-Gestion de payload des requète explicite et stricte
-Gestion des websocket plus complexe
-Dockerisation simple
-CD simple (déploiment automatique sur un VPS linux)

View File

@@ -7,7 +7,6 @@ use serde::Deserialize;
#[derive(Deserialize, Debug)]
pub struct CreateConversationValues {
//pub creator_id: i32, // already in token ?
pub name: String,
}

View File

@@ -297,17 +297,14 @@ pub async fn send_message(
match statement.exists(params![user_id, payload.conv_id]) {
Ok(true) => {
// user is part of the conversation — continue
}
Ok(false) => {
// early exit: not part of the conversation
return (
StatusCode::FORBIDDEN,
"Not part of the conversation".to_string(),
);
}
Err(_) => {
// query failed
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Query failed".to_string(),
@@ -330,10 +327,6 @@ pub async fn send_message(
);
}
};
//let message_id = conn.last_insert_rowid();
// FIXME: add sent_at and message id in the response.
// --- send to conversation participants ---
let mut stmt_participants = conn
.prepare("SELECT user_id FROM conversation_participants WHERE conversation_id = ?1")
.expect("prepare participants failed");
@@ -370,7 +363,6 @@ pub async fn send_message(
),
)
}
//Ok(_) => (StatusCode::NOT_FOUND, "Conversation not found".to_string()),
#[derive(Debug, Serialize)]
struct Message {
@@ -514,7 +506,6 @@ struct Conversation {
title: String,
}
//FIXME: allow null conv name ? default to persons name
pub async fn get_convs(
State(state): State<AppState>,
//Path((item_name, item_amount)): Path<(String, i32)>,
@@ -553,7 +544,6 @@ pub async fn get_convs(
})
}) {
Ok(rows) => rows,
//Ok(_) => {}, IMPLEMENT NO CONV ?
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
@@ -571,9 +561,6 @@ pub async fn get_convs(
);
}
};
//.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, Json("error".to_string())));
match serde_json::to_string(&convs) {
Ok(json) => (StatusCode::OK, json),
Err(e) => (
@@ -582,69 +569,3 @@ pub async fn get_convs(
),
}
}
/*
pub async fn get_convs(
State(state): State<AppState>,
//Path((item_name, item_amount)): Path<(String, i32)>,
AuthClaims{ user_id, hotel_id}: AuthClaims,
) -> impl IntoResponse {
let pool = state.hotel_pools.get_pool(hotel_id);
let conn = match pool.get(){
Ok(conn) => conn,
Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Pool error: {}", err).into_response() )
};
let mut stmt = match conn.prepare(
"SELECT id, title FROM conversation WHERE creator_id = ?1",
) {
Ok(s) => s,
Err(e) =>
return (StatusCode::INTERNAL_SERVER_ERROR, format!("Prepare failed: {}", e).into_response() )
};
let rows = match stmt.query_map(params![user_id], |row| {
let id: i32 = row.get(0)?;
let title: String = row.get(1)?;
Ok((title, id))
}) {
Ok(rows) => rows,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Query failed: {}", e).into_response() )
};
let mut map = HashMap::new();
// ✅ Iterate through the row results
for row_result in rows {
match row_result {
Ok((title, id)) => {
map.insert(title, id);
}
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("Row parsing failed: {}", e).into_response() )
}
}
let conv_map_json = match to_value(map) {
Ok(c) => c,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, format!("List unwrapping failed: {}", e).into_response() )
};
let conv_map_clean_json = serde_json::to_value(map)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Serialization failed: {}", e).into_response() ));
(StatusCode::OK, Json(conv_map_clean_json)).into_response()
}
*/

View File

@@ -1,4 +1,3 @@
pub mod routes;
mod extractor;
mod handlers;

View File

@@ -60,9 +60,6 @@ pub async fn update_inventory_item(
Path((item_id, item_amount)): Path<(i32, i32)>,
AuthClaims { user_id, hotel_id }: AuthClaims,
) -> impl IntoResponse {
//TODO: make better error handling :
// if wrong param collumn targeted,
// if missing path param
let pool = state.hotel_pools.get_pool(hotel_id);
@@ -119,14 +116,6 @@ pub async fn update_inventory_item(
format!("Error from DB: {err}"),
),
}
/*
match result {
Ok(row) => (StatusCode::OK, format!("Items updated")),
Ok(_) => (StatusCode::NOT_FOUND, format!("No item with this id exist")),
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, format!("error updating the item with id :{} with amount: {}", item_id, item_amount))
}
*/
}
pub async fn get_inventory_item(
@@ -169,7 +158,6 @@ pub async fn get_inventory_item(
items.push(item);
}
// Serialize to JSON
let json = match serde_json::to_string(&items) {
Ok(j) => j,
Err(_) => {

View File

@@ -1,3 +1,2 @@
mod handler;
pub mod routes;

View File

@@ -1,37 +1,38 @@
use axum::Extension;
use dotenvy::dotenv;
use std::env;
use axum::{Extension,serve};
use axum::extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
};
use axum::serve;
use jsonwebtoken::{DecodingKey, EncodingKey};
use reqwest::header::AUTHORIZATION;
use reqwest::header::CONTENT_TYPE;
use reqwest::header::USER_AGENT;
use axum::http::{HeaderValue, Method};
use jsonwebtoken::{DecodingKey, EncodingKey};
use reqwest::header::{
AUTHORIZATION,
CONTENT_TYPE,
USER_AGENT};
use tokio::net::TcpListener;
use tower_http::cors::{Any, CorsLayer};
use dashmap::DashMap;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use std::sync::Arc;
use crate::routes::create_router;
use crate::utils::auth::JwtKeys;
use crate::utils::db_pool::{AppState, HotelPool};
use routes::create_router;
use dotenvy::dotenv;
use std::env;
//use tower_http::cors::Origin;
use axum::http::{HeaderValue, Method};
use tower_http::cors::{Any, CorsLayer};
mod chat;
mod inventory;
mod rooms;
mod routes;
mod utils;
use dashmap::DashMap;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use std::sync::Arc;
//Send notification on discord through webhook
pub async fn notify_discord(msg: &str) -> Result<(), reqwest::Error> {
let payload = serde_json::json!({
"content": msg
@@ -50,14 +51,12 @@ pub async fn notify_discord(msg: &str) -> Result<(), reqwest::Error> {
async fn main() -> std::io::Result<()> {
dotenv().ok();
//send crahs notification t odiscord webhook
std::panic::set_hook(Box::new(|info| {
let msg = format!("Rust panic: {}", info);
// Use blocking client so the process can't exit before sending
let payload = serde_json::json!({
"content": msg
});
let client = reqwest::blocking::Client::new();
let _ = client
.post("https://discord.com/api/webhooks/1440912618205347891/Ekg89krDoPm41kA27LA3gXgNWmMWvCCtziYIUsjqaY22Jnw4a6IWhZOht0in5JjnPX-W")
@@ -65,24 +64,20 @@ async fn main() -> std::io::Result<()> {
.send();
}));
//panic!("crash-test");
let hotel_pools = HotelPool::new();
//DB pools
let logs_manager = SqliteConnectionManager::file("db/auth_copy_2.sqlite");
let logs_pool = Pool::builder()
.max_size(5)
.build(logs_manager)
.expect("Failed to build logs pool");
let hotel_pools = HotelPool::new();
let state = AppState {
hotel_pools,
logs_pool,
ws_map: Arc::new(DashMap::new()),
//jwt_secret: "your_jwt_secret_key s".to_string(), // better: load from env var
};
//let jwt_secret = "your_jwt_secret_key".to_string();
let jwt_secret = env::var("JWT_SECRET")
.expect("JWT_SECRET must be set")
.to_string();
@@ -92,12 +87,12 @@ async fn main() -> std::io::Result<()> {
decoding: DecodingKey::from_secret(jwt_secret.as_ref()),
};
let allowed_origins = vec!["http://82.66.253.209", "http://localhost:5173"];
let cors = CorsLayer::very_permissive()
.allow_credentials(true)
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::OPTIONS])
.allow_headers([CONTENT_TYPE, AUTHORIZATION, USER_AGENT]);
//create router using entry point from "src/routes/mod.rs"
let app = create_router(state).layer(Extension(jwt_keys)).layer(cors);
let listener = TcpListener::bind("0.0.0.0:7080").await?;
@@ -105,6 +100,3 @@ async fn main() -> std::io::Result<()> {
Ok(())
}
async fn handler() -> &'static str {
"Hiii from localhost"
}

View File

@@ -61,10 +61,6 @@ pub async fn clean_db_update(
AuthClaims { user_id, hotel_id }: AuthClaims,
UpdateRoomPayload(payload): UpdateRoomPayload,
) -> impl IntoResponse {
//TODO: make better error handling :
// if wrong param collumn targeted,
// if missing path param
let pool = state.hotel_pools.get_pool(hotel_id);
let conn = match pool.get() {

View File

@@ -1,6 +1,3 @@
//pub mod handler;
//pub mod routes;
mod extractor;
mod handler;
pub mod routes;

View File

@@ -2,11 +2,9 @@ use axum::{
Router,
routing::{get, put},
};
use crate::rooms::handler::*;
use crate::utils::db_pool::{AppState, HotelPool};
// ROOTS
pub fn rooms_routes() -> Router<AppState> {
Router::new()
.route("/", get(hello_rooms))

View File

@@ -6,13 +6,8 @@ use crate::chat::routes::chat_routes;
use crate::inventory::routes::inventory_routes;
use crate::rooms::routes::rooms_routes;
use crate::utils::routes::utils_routes;
use crate::utils::db_pool::AppState;
//TODO: add secret fomr dotenv here
/*
Function to build our main router
that regroup all feature centered router
*/
pub fn create_router(state: AppState) -> Router {
Router::new()
.nest("/auth", utils_routes().with_state(state.clone()))

View File

@@ -1,45 +1,40 @@
use axum::{
Json,
body::{Body, to_bytes},
extract::{Extension, FromRequest, FromRequestParts, Path, State, ws::close_code::STATUS},
middleware::Next,
extract::{Extension, Path, State,
FromRequest, FromRequestParts, FromRef, Request as ExtractRequest,
ws::close_code::STATUS
},
http::{
Request as HttpRequest, StatusCode,
header::{HeaderValue, SET_COOKIE},
request::Parts,
status,
},
middleware::Next,
body::{Body, to_bytes},
response::{IntoResponse, IntoResponseParts, Response},
};
use std::time::Duration;
use axum_extra::extract::TypedHeader;
//use axum_extra::TypedHeader;
use headers::{UserAgent};
use futures_util::future::TrySelect;
use headers::{Cookie, UserAgent};
use axum::extract::FromRef;
use axum::extract::Request as ExtractRequest;
use chrono::{Utc, format};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use reqwest::header::REFRESH;
use rusqlite::{Connection, OptionalExtension, params};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use argon2::{
Argon2,
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
};
use std::time::Duration;
use chrono::{Utc, format};
use rand_core::{OsRng, RngCore};
use uuid::Uuid;
use serde::{Deserialize, Serialize};
use base64::{Engine as _, engine::general_purpose};
use rusqlite::{Connection, OptionalExtension, params};
//use crate::utils::db_pool::;
use crate::utils::db_pool::{AppState, HotelPool};
use base64::{Engine as _, engine::general_purpose};
#[derive(Clone)]
pub struct JwtKeys {
@@ -49,19 +44,16 @@ pub struct JwtKeys {
pub async fn token_tester(
State(state): State<AppState>,
//Extension(keys): Extension<JwtKeys>,
AuthClaims { user_id, hotel_id }: AuthClaims,
) -> impl IntoResponse {
format!("(user_id: {}) from hotel {}", user_id, hotel_id)
}
pub struct AuthUser(pub Claims); //??
#[derive(Debug, Clone)]
pub struct AuthClaims {
pub user_id: i32,
pub hotel_id: i32,
//pub username: String,
}
pub fn auth_claims_from_token(
@@ -114,18 +106,16 @@ where
// Hash a new password
fn hash_password(password: &str) -> anyhow::Result<String> {
let salt = SaltString::generate(&mut OsRng); // unique per password
let argon2 = Argon2::default(); // Argon2id with good defaults
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| anyhow::anyhow!(e))?
.to_string();
Ok(password_hash)
}
// Verify an incoming password against stored hash
fn verify_password(password: &str, stored_hash: &str) -> bool {
let parsed_hash = match PasswordHash::new(&stored_hash) {
Ok(hash) => hash,
@@ -142,7 +132,7 @@ pub struct RegisterValues {
username: String,
password: String,
#[serde(default)]
hotel_ids: Vec<i32>, //-> :Vec!<32>, maybe optionnal ?
hotel_ids: Vec<i32>,
displayname: String,
}
@@ -163,7 +153,6 @@ where
}
//TODO: Validate all hotel_ids first + Use a transaction + Batch query hotel names with IN (...)
pub async fn register_user(
State(state): State<AppState>,
RegisterPayload(payload): RegisterPayload,
@@ -197,7 +186,6 @@ pub async fn register_user(
let user_id = conn.last_insert_rowid();
for &hotel_id in &payload.hotel_ids {
// more logic for security here
//FIXME: needs to be the display name in the DB, scheme is currently wrong
let hotel_name: String = conn
@@ -244,8 +232,6 @@ pub struct ForceUpdatePasswordValues {
admin_pass: String,
}
//pub struct ForceUpdatePasswordPayload (pub ForceUpdatePasswordValues);
pub async fn force_update_password(
State(state): State<AppState>,
Json(payload): Json<ForceUpdatePasswordValues>,
@@ -310,7 +296,6 @@ pub struct UpdatePasswordValues {
username: String,
current_password: String,
newpassword: String,
//hotel_id: i32,
}
pub async fn update_password(
@@ -425,7 +410,7 @@ pub async fn clean_auth_loging(
Extension(keys): Extension<JwtKeys>,
LoginPayload(payload): LoginPayload,
) -> impl IntoResponse {
// 1⃣ Get a connection from logs pool
let conn = match state.logs_pool.get() {
Ok(c) => c,
Err(_) => {
@@ -475,7 +460,6 @@ pub async fn clean_auth_loging(
let claims = serde_json::json!({
"id": user_id,
"hotel_id": hotel_id,
//"username": payload.username,
"exp": expiration
});
@@ -493,18 +477,14 @@ pub struct CreateRefreshTokenValue {
pub username: String,
pub password: String,
pub device_id: Uuid,
//pub timestamp: Option<String>,
}
//FIXME: weird return type, returning result ?
#[axum::debug_handler]
pub async fn create_refresh_token(
State(state): State<AppState>,
user_agent: Option<TypedHeader<UserAgent>>,
Json(payload): Json<CreateRefreshTokenValue>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
// ← Add Result here
let user_agent_str = user_agent
.map(|ua| ua.to_string())
@@ -526,16 +506,6 @@ pub async fn create_refresh_token(
let raw_token = Uuid::new_v4().to_string();
let hashed_token = &raw_token;
/*
let hashed_token = argon2
.hash_password(raw_token.as_bytes(), &salt)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.to_string();
*/
// let mut stmt = conn.prepare(
// "SELECT id, password FROM users WHERE username = ?1"
let credentials = match conn.query_row(
"SELECT id, password FROM users WHERE username = ?1",
params![&payload.username],
@@ -556,20 +526,13 @@ pub async fn create_refresh_token(
let (user_id, user_password) = credentials;
/*
let (user_id, stored_hash, hotel_id) = user_row
.ok_or((StatusCode::NOT_FOUND, "User not found".to_string()))?;
*/
//let mut tokens = Vec::new();
//TODO: validate password
if !verify_password(&payload.password, &user_password) {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid credential".to_string(),
)); // Skip rows with invalid password
));
}
//TODO: get hotel name to return a map/tuple of hotel name
let mut stmt = match conn.prepare("SELECt hotel_id FROM hotel_user_link WHERE user_id = ?1") {
Ok(stmt) => stmt,
Err(_) => {
@@ -580,7 +543,6 @@ pub async fn create_refresh_token(
}
};
//TODO: compiler les hotel id dans un vecteur pour le feed dans le refresh token
let hotel_ids: Vec<i32> = match stmt.query_map(params![&user_id], |row| row.get(0)) {
Ok(rows) => rows.collect::<Result<Vec<_>, _>>().map_err(|_| {
(
@@ -606,12 +568,6 @@ pub async fn create_refresh_token(
}
};
/*.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Error mapping hotel_ids".to_string())); */
//FIXME: might not need the hotel list on tconflict ?
//TODO: remove user agent entirely from auth ,it is mutable and not stable
//TODO: make the token refresh on login
conn.execute(
r#"
INSERT INTO refresh_token (
@@ -635,8 +591,6 @@ pub async fn create_refresh_token(
)
})?;
//TODO: add a map/tupple of of the allowed hotels and their id+name, maybe update the token ?
println!("RAW write refresh_token bytes: {:?}", &raw_token.as_bytes());
println!("RAW refresh_token : {}", &raw_token.to_string());
println!("RAW write refresh_token len: {}", &raw_token.len());
@@ -657,10 +611,8 @@ pub async fn create_refresh_token(
#[derive(Deserialize)]
pub struct LoginRefreshTokenValues {
device_id: Uuid,
//refresh_token: String,
}
//TODO: LATER : implement hotel-id-selected to allow user to only get part hotels ?
pub async fn login_refresh_token(
State(state): State<AppState>,
Extension(keys): Extension<JwtKeys>,
@@ -669,7 +621,6 @@ pub async fn login_refresh_token(
Json(payload): Json<LoginRefreshTokenValues>,
) -> impl IntoResponse {
println!("login_refresh_token called");
// Log cookies
let cookies = match cookie_header {
Some(token) => token,
@@ -684,7 +635,6 @@ pub async fn login_refresh_token(
println!("RAW refresh_token bytes: {:?}", refresh_token.as_bytes());
println!("RAW refresh_token : {}", refresh_token.to_string());
println!("RAW refresh_token len: {}", refresh_token.len());
println!("Cookies: {:?}", &refresh_token);
let conn = match state.logs_pool.get() {
@@ -700,14 +650,8 @@ pub async fn login_refresh_token(
};
println!("UA {:?}", &user_agent_str);
let device_id_str = payload.device_id.to_string();
println!("device id: {:?}", &device_id_str);
//"SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2",
//TODO: swap to query row and get hotel-id's list and not single hotel per row
//deserializing the list :
//let hotel_ids: Vec<i32> = serde_json::from_str(&stored_value)?;
let mut stmt = match conn.prepare(
"SELECT user_id, hotel_id_list
FROM refresh_token
@@ -727,9 +671,8 @@ pub async fn login_refresh_token(
let rows = match stmt
.query_one(params![&device_id_str, &refresh_token], |row| {
Ok((
row.get::<_, i32>(0)?, // user_id
row.get::<_, String>(1)?, // token_hash
//row.get::<_, String>(2)?, // hotel_id //FIXME: this is supposed to be vectore maybe ?
row.get::<_, i32>(0)?,
row.get::<_, String>(1)?,
))
})
.optional()
@@ -744,8 +687,7 @@ pub async fn login_refresh_token(
.into_response();
}
};
//TODO: extraction of the blob
//let json_hotel_ids = rows.2;
let (user_id, json_hotel_ids) = match rows {
Some(r) => r,
None => {
@@ -756,6 +698,7 @@ pub async fn login_refresh_token(
.into_response();
}
};
let hotel_ids: Vec<i32> = match serde_json::from_str(&json_hotel_ids) {
Ok(ids) => ids,
Err(_) => {
@@ -767,23 +710,10 @@ pub async fn login_refresh_token(
}
};
//FIXME: still problems when corrupted token exist
if hotel_ids.is_empty() {
return (StatusCode::UNAUTHORIZED, "No matching device").into_response();
}
/*
eprintln!("DB ERROR: {:?}", &refresh_token);
eprintln!("DB ERROR: {:?}", &token);
//still not auto adding hotel user link when creating account
if (&refresh_token != &token) {
// skip rows with wrong hash
return (StatusCode::UNAUTHORIZED, "Invelid credentials").into_response();
}
*/
let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) {
Some(time) => time.timestamp() as usize,
None => {
@@ -819,7 +749,6 @@ pub async fn login_refresh_token(
return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response();
}
//Json(tokens).into_response()
Json(MultiLoginResponse { user_id, tokens }).into_response()
}
@@ -854,9 +783,8 @@ pub async fn logout_from_single_device(
}
};
let device_row = match conn
.query_row(
"SELECT user_id, hotel_id_list, id
let device_row = match conn.query_row(
"SELECT user_id, hotel_id_list, id
FROM refresh_token
WHERE token_hash = ?1 AND revoked = 0 ",
params![&refresh_token],
@@ -867,16 +795,14 @@ pub async fn logout_from_single_device(
//let displayname: String = row.get(3)?;
Ok((user_id, json_hotel_id_list, id))
},
)
.optional()
).optional()
{
Ok(opt) => opt,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("DB query error : {}", e),
)
.into_response();
).into_response();
}
};
@@ -891,18 +817,10 @@ pub async fn logout_from_single_device(
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Hotel ids are not deserializable to Vec",
)
.into_response();
).into_response();
}
};
//FIXME: need to chang the way we get refresh token from the cookies instead
/*
if !verify_password(&payload.refresh_token, &token_hash) {
return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response();
}
*/
let revoked: Result<String, rusqlite::Error> = conn.query_row(
"DELETE FROM refresh_token
WHERE id = ?1
@@ -911,14 +829,13 @@ pub async fn logout_from_single_device(
|row| row.get(0),
);
let revoked_id = match (revoked) {
let revoked_id = match revoked {
Ok(r) => r,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Hotel ids are not deserializable to Vec",
)
.into_response();
).into_response();
}
};
@@ -930,11 +847,9 @@ pub async fn logout_from_single_device(
let mut response = (
StatusCode::CREATED,
format!("Token deleted for device id {}", &revoked_id),
)
.into_response();
).into_response();
response
.headers_mut()
response.headers_mut()
.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
response
@@ -944,9 +859,7 @@ pub async fn logout_from_all_devices(
State(state): State<AppState>,
Extension(keys): Extension<JwtKeys>,
AuthClaims { user_id, hotel_id }: AuthClaims,
//Json(payload): Json<LoginRefreshTokenValues>
) -> impl IntoResponse {
//let device_id_str = payload.device_id.to_string();
let conn = match state.logs_pool.get() {
Ok(c) => c,
@@ -960,21 +873,6 @@ pub async fn logout_from_all_devices(
params![&user_id],
);
/*
match result {
//Ok(count) if count > 0 => {
// (StatusCode::OK, format!("Revoked {} active tokens", count)).into_response()
//}
//Ok(_) => (StatusCode::NOT_FOUND, "No active tokens to revoke").into_response(),
Err(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
"Database update error".to_string(),
)
.into_response(),
}
*/
let cookie_value = format!(
"refresh_token={}; HttpOnly; Secure; Max-Age=0;Path=/",
"loggedout"
@@ -988,18 +886,13 @@ pub async fn logout_from_all_devices(
.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
match result {
//Ok(count) if count > 0 => {
// (StatusCode::OK, format!("Revoked {} active tokens", count)).into_response()
//}
Ok(_) => response,
Err(err) => (
StatusCode::INTERNAL_SERVER_ERROR,
err.to_string(), // or format!("{err:?}")
err.to_string(),
)
.into_response(),
}
//response
}
#[derive(Serialize)]
@@ -1072,9 +965,6 @@ pub async fn get_hotel(State(state): State<AppState>) -> impl IntoResponse {
.into_response();
}
};
//.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error".to_string()))?;
//return (StatusCode::OK).into_response();
}
#[derive(Deserialize, Debug)]
@@ -1165,8 +1055,6 @@ pub async fn add_hotel_user(
.into_response();
}
};
//TODO: still need to build the add hotel to user here
}
return (StatusCode::OK, "goo").into_response();

View File

@@ -17,7 +17,7 @@ pub struct AppState {
pub ws_map: WsMap,
}
type HotelId = i32; // or i32 if you want numeric ids
type HotelId = i32;
#[derive(Clone)]
pub struct HotelPool {

View File

@@ -28,10 +28,8 @@ pub type WsMap = Arc<HotelMap>;
/// Type alias: user_id → sender to that user
async fn handle_socket(mut socket: WebSocket, state: AppState, hotel_id: i32, user_id: i32) {
// channel for sending messages TO this client
let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
// insert into hotel → user map
let user_map = state
.ws_map
.entry(hotel_id)
@@ -39,14 +37,10 @@ async fn handle_socket(mut socket: WebSocket, state: AppState, hotel_id: i32, us
.clone();
user_map.insert(user_id, tx);
// ✅ print after upgrading
print_ws_state(&state);
// split socket into sender/receiver
let (mut sender, mut receiver) = socket.split();
// task for sending messages from server to client
let mut rx_task = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if sender.send(msg).await.is_err() {
@@ -55,7 +49,6 @@ async fn handle_socket(mut socket: WebSocket, state: AppState, hotel_id: i32, us
}
});
// task for receiving messages from client
let state_clone = state.clone();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
@@ -75,13 +68,11 @@ async fn handle_socket(mut socket: WebSocket, state: AppState, hotel_id: i32, us
}
});
// wait for either side to finish
tokio::select! {
_ = (&mut rx_task) => recv_task.abort(),
_ = (&mut recv_task) => rx_task.abort(),
}
// cleanup
user_map.remove(&user_id);
if user_map.is_empty() {
state.ws_map.remove(&hotel_id);
@@ -89,7 +80,6 @@ async fn handle_socket(mut socket: WebSocket, state: AppState, hotel_id: i32, us
}
pub async fn ws_handler(
//AuthClaims {user_id, hotel_id}: AuthClaims,
ws: WebSocketUpgrade,
Extension(keys): Extension<JwtKeys>,
State(state): State<AppState>,
@@ -106,14 +96,6 @@ pub async fn ws_handler(
};
print!("{token}, web socket tried to connect",);
/*
let claims = match auth_claims_from_token(&token, &keys) {
Ok(c) => c,
Err(_) => return StatusCode::UNAUTHORIZED.into_response(),
};
*/
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.hotel_id, claims.user_id))
}