diff --git a/Cargo.lock b/Cargo.lock index 8cbfea3..79369cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -409,7 +409,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.1", ] [[package]] @@ -672,6 +672,7 @@ dependencies = [ "serde", "serde_json", "tokio", + "tower-http", "uuid", ] @@ -1469,7 +1470,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.61.1", ] [[package]] @@ -1764,7 +1765,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix", - "windows-sys 0.59.0", + "windows-sys 0.61.1", ] [[package]] @@ -1922,9 +1923,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" dependencies = [ "bitflags", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 7654ff2..07244a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ futures-util = {version = "0.3.31"} uuid = {version = "1.18.1", features = ["serde"] } base64 = "0.22.1" reqwest = { version = "0.12.24", features = ["json","blocking"] } +tower-http = { version = "0.6.7", features = ["cors"] } diff --git a/db/1.sqlite b/db/1.sqlite index 369beb4..0a33887 100644 Binary files a/db/1.sqlite and b/db/1.sqlite differ diff --git a/db/1.sqlite-shm b/db/1.sqlite-shm deleted file mode 100644 index fe9ac28..0000000 Binary files a/db/1.sqlite-shm and /dev/null differ diff --git a/db/1.sqlite-wal b/db/1.sqlite-wal deleted file mode 100644 index e69de29..0000000 diff --git a/db/auth_copy_2.sqlite b/db/auth_copy_2.sqlite index 80bbf23..761aa8a 100644 Binary files a/db/auth_copy_2.sqlite and b/db/auth_copy_2.sqlite differ diff --git a/db/auth_copy_2.sqlite-shm b/db/auth_copy_2.sqlite-shm deleted file mode 100644 index 3564af8..0000000 Binary files a/db/auth_copy_2.sqlite-shm and /dev/null differ diff --git a/db/auth_copy_2.sqlite-wal b/db/auth_copy_2.sqlite-wal deleted file mode 100644 index 3f04b9d..0000000 Binary files a/db/auth_copy_2.sqlite-wal and /dev/null differ diff --git a/src/main.rs b/src/main.rs index 3a391c6..d8afa8f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,8 @@ use axum::serve; use axum::Extension; use axum::extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State}; use jsonwebtoken::{DecodingKey, EncodingKey}; +use reqwest::header::AUTHORIZATION; +use reqwest::header::CONTENT_TYPE; use tokio::net::TcpListener; use tokio::sync::mpsc; @@ -24,6 +26,8 @@ use crate::utils::auth::JwtKeys; use std::env; use dotenvy::dotenv; +use tower_http::cors::{CorsLayer, Any}; +use axum::http::{Method, HeaderValue}; pub async fn notify_discord(msg: &str) -> Result<(), reqwest::Error> { let payload = serde_json::json!({ @@ -88,10 +92,16 @@ std::panic::set_hook(Box::new(|info| { }; + let cors = CorsLayer::new() + .allow_origin("http://localhost:5173".parse::().unwrap()) + .allow_credentials(true) + .allow_methods([Method::GET, Method::POST, Method::PUT , Method::OPTIONS]) + .allow_headers([CONTENT_TYPE, AUTHORIZATION]); let app = create_router(state) - .layer(Extension(jwt_keys)); - + .layer(Extension(jwt_keys)) + .layer(cors); + let listener = TcpListener::bind("0.0.0.0:7080").await?; serve(listener, app).into_future().await?; Ok(()) diff --git a/src/utils/auth.rs b/src/utils/auth.rs index cef5cba..4f73c27 100644 --- a/src/utils/auth.rs +++ b/src/utils/auth.rs @@ -6,14 +6,15 @@ use axum::{ use axum_extra::extract::TypedHeader; //use axum_extra::TypedHeader; -use headers::UserAgent; +use headers::{UserAgent, Cookie}; use axum::extract::FromRef; use axum::extract::Request as ExtractRequest; use jsonwebtoken::{decode, DecodingKey, Validation, encode, EncodingKey, Header, Algorithm}; +use reqwest::header::REFRESH; use serde::{Deserialize, Serialize}; use serde_json::Value; -use chrono::{Utc}; +use chrono::{Utc, format}; use rusqlite::{params, Connection, OptionalExtension}; use rand_core::{RngCore, OsRng}; @@ -468,16 +469,15 @@ pub async fn create_refresh_token( - //TODO: get hotel name + //TODO: get hotel name to return a map/tuple of hotel name let mut stmt = match conn.prepare( "SELECt hotel_id FROM hotel_user_link WHERE user_id = ?1" ) { Ok(stmt) => stmt, Err(_) => return Err((StatusCode::INTERNAL_SERVER_ERROR, "error building user_id fetch stmt".to_string())), }; - //TODO: compiler les hotel id dans un vecteur pour le feed dans le refresh token - //Deja fait ? + //TODO: compiler les hotel id dans un vecteur pour le feed dans le refresh token let hotel_ids: Vec = match stmt .query_map(params![&user_id],|row| row.get (0)) { @@ -539,7 +539,9 @@ pub async fn create_refresh_token( */ - let cookie_value = format!("refresh_token={}; HttpOnly; Secure; Path=/", raw_token); + //TODO: add a map/tupple of of the allowed hotels and their id+name, maybe update the token ? + + let cookie_value = format!("refresh_token={}; HttpOnly; Secure; Max-Age=60480000000;Path=/", raw_token); let mut response = (StatusCode::CREATED, "Refresh token created successfully").into_response(); response.headers_mut().insert( @@ -553,7 +555,7 @@ pub async fn create_refresh_token( #[derive(Deserialize)] pub struct LoginRefreshTokenValues{ device_id: Uuid, - refresh_token: String, + //refresh_token: String, } //TODO: LATER : implement hotel-id-selected to allow user to only get part hotels ? @@ -561,29 +563,39 @@ pub async fn login_refresh_token ( State(state): State, Extension(keys): Extension, user_agent: Option>, + cookie_header: Option>, Json(payload): Json ) -> impl IntoResponse { + println!("login_refresh_token called"); + // Log cookies + + + let cookies = match cookie_header { + Some(token) => token, + None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(), + }; + + let refresh_token = match cookies.get("refresh_token") { + Some(token) => token.to_string(), + None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(), + }; + + println!("Cookies: {:?}", &refresh_token); + let conn = match state.logs_pool.get() { Ok(c) => c, Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response(), }; - - /*let user_agent_str = user_agent - .map(|ua| ua.to_string()) - .unwrap_or_else(|| "Unknown".to_string()); - - let device_id_str = payload.device_id.to_string(); - */ - let user_agent_str = match user_agent { Some(ua) => ua.to_string(), None => return (StatusCode::INTERNAL_SERVER_ERROR, "user agent unknown").into_response(), }; - + println!("UA {:?}", &user_agent_str); let device_id_str = payload.device_id.to_string(); + println!("device id: {:?}", &device_id_str); //"SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2", @@ -593,10 +605,11 @@ pub async fn login_refresh_token ( let mut stmt = match conn.prepare( "SELECT user_id, token_hash, hotel_id_list FROM refresh_token - WHERE device_id = ?1 AND user_agent = ?2 " + WHERE device_id = ?1 AND user_agent = ?2 + LIMIT 1;" ) { Ok(s) => s, - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(), + Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "Error prepatring hotel_id_list stmt").into_response(), }; let rows = match stmt.query_one(params![&device_id_str, &user_agent_str], |row| { @@ -607,7 +620,10 @@ pub async fn login_refresh_token ( )) }) { Ok(r) => r, - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(), + Err(e) => { + eprintln!("DB ERROR: {:?}", e); + return (StatusCode::INTERNAL_SERVER_ERROR, format!("DB query error: {}", e)).into_response() + } }; //TODO: extraction of the blob //let json_hotel_ids = rows.2; @@ -619,21 +635,11 @@ pub async fn login_refresh_token ( }; -/* - let mut entries = Vec::new(); - for r in rows { - match r { - Ok(t) => entries.push(t), - Err(_) => continue, // ignore corrupt rows - } - } - */ - if hotel_ids.is_empty() { return (StatusCode::UNAUTHORIZED, "No matching device").into_response(); } - if !verify_password(&payload.refresh_token, &saved_hash) { + if !verify_password(&refresh_token, &saved_hash) { // skip rows with wrong hash return (StatusCode::UNAUTHORIZED, "Invelid credentials").into_response(); } @@ -663,31 +669,6 @@ pub async fn login_refresh_token ( tokens.push(token); } - /* OLD ITERATION OVER MULTIPLE REFRESH TOKEN - // swap to "for hotel_id in entries" // interator over vector list - for (user_id, token_hash, hotel_id) in entries { - if !verify_password(&payload.refresh_token, &token_hash) { - // skip rows with wrong hash - continue; - } - //FIXME: single expiration - - - let claims = serde_json::json!({ - "id": user_id, - "hotel_id": hotel_id, - "exp": expiration - }); - - let token = match encode(&Header::default(), &claims, &keys.encoding) { - Ok(t) => t, - Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response(), - }; - - tokens.push(token); - } - */ - if tokens.is_empty() { return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response(); } @@ -736,10 +717,12 @@ pub async fn logout_from_single_device ( None => return (StatusCode::UNAUTHORIZED, "No matching device").into_response(), }; + //FIXME: need to chang the way we get refresh token from the cookies instead + /* if !verify_password(&payload.refresh_token, &token_hash) { return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response(); } - +*/ let revoked: Result = conn.query_row( "UPDATE refresh_token SET revoked = 1 WHERE id = ?1 RETURNING device_id", params![&token_id],