Files
hotel_api/src/utils/auth.rs
Romain Mallard e5a1d36654
Some checks failed
Deploy API / build-and-deploy (push) Failing after 5s
rust fmt and some cleaning
2026-03-11 14:22:46 +01:00

1181 lines
34 KiB
Rust
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
Json,
body::{Body, to_bytes},
extract::{Extension, FromRequest, FromRequestParts, Path, State, ws::close_code::STATUS},
http::{
Request as HttpRequest, StatusCode,
header::{HeaderValue, SET_COOKIE},
request::Parts,
status,
},
middleware::Next,
response::{IntoResponse, IntoResponseParts, Response},
};
use std::time::Duration;
use axum_extra::extract::TypedHeader;
//use axum_extra::TypedHeader;
use futures_util::future::TrySelect;
use headers::{Cookie, UserAgent};
use axum::extract::FromRef;
use axum::extract::Request as ExtractRequest;
use chrono::{Utc, format};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use reqwest::header::REFRESH;
use rusqlite::{Connection, OptionalExtension, params};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use argon2::{
Argon2,
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
};
use rand_core::{OsRng, RngCore};
use uuid::Uuid;
//use crate::utils::db_pool::;
use crate::utils::db_pool::{AppState, HotelPool};
use base64::{Engine as _, engine::general_purpose};
/// Key pair for the HS256 access-token JWTs: `encoding` signs, `decoding` verifies.
///
/// Handlers receive it through `Extension<JwtKeys>`, which is why it is `Clone`.
#[derive(Clone)]
pub struct JwtKeys {
    /// Passed to `jsonwebtoken::encode` when issuing tokens.
    pub encoding: EncodingKey,
    /// Passed to `jsonwebtoken::decode` when validating tokens.
    pub decoding: DecodingKey,
}
/// Returns a plain-text description of the caller identified by the bearer token.
///
/// Body: `"(user_id: <id>) from hotel <hotel_id>"`. Missing or invalid tokens are
/// rejected earlier by the [`AuthClaims`] extractor.
pub async fn token_tester(
    State(_state): State<AppState>,
    claims: AuthClaims,
) -> impl IntoResponse {
    let AuthClaims { user_id, hotel_id } = claims;
    format!("(user_id: {user_id}) from hotel {hotel_id}")
}
/// Newtype around the raw decoded JWT [`Claims`].
///
/// NOTE(review): not referenced anywhere in this file — confirm whether it is still needed.
pub struct AuthUser(pub Claims);

/// Identity extracted from a verified `Authorization: Bearer <jwt>` header.
#[derive(Clone, Debug)]
pub struct AuthClaims {
    /// The token's `id` claim (the user's `users.id`).
    pub user_id: i32,
    /// The token's `hotel_id` claim: the hotel this token was issued for.
    pub hotel_id: i32,
}
pub fn auth_claims_from_token(
token: &str,
keys: &JwtKeys,
) -> Result<AuthClaims, (StatusCode, String)> {
let token_data = decode::<Claims>(token, &keys.decoding, &Validation::new(Algorithm::HS256))
.map_err(|_| (StatusCode::UNAUTHORIZED, "Invalid token".into()))?;
Ok(AuthClaims {
user_id: token_data.claims.id,
hotel_id: token_data.claims.hotel_id,
})
}
/// Extracts [`AuthClaims`] from an `Authorization: Bearer <jwt>` request header.
///
/// Requires a [`JwtKeys`] request extension to be present.
///
/// Rejections:
/// - `401` when the `JwtKeys` extension or the `Authorization` header is missing, or the
///   token fails verification;
/// - `400` when the header is not valid visible ASCII or does not use the `Bearer` scheme.
impl<S> FromRequestParts<S> for AuthClaims
where
    S: Send + Sync + 'static,
    AppState: FromRef<S>,
{
    type Rejection = (StatusCode, String);

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let Extension(keys): Extension<JwtKeys> = Extension::from_request_parts(parts, state)
            .await
            .map_err(|_| (StatusCode::UNAUTHORIZED, "Missing keys".into()))?;
        let auth_header = parts
            .headers
            .get(axum::http::header::AUTHORIZATION)
            .ok_or((
                StatusCode::UNAUTHORIZED,
                "Missing Authorization header".into(),
            ))?
            .to_str()
            .map_err(|_| {
                (
                    StatusCode::BAD_REQUEST,
                    "Invalid Authorization header".into(),
                )
            })?;
        let token = bearer_token(auth_header)
            .ok_or((StatusCode::BAD_REQUEST, "Expected Bearer token".into()))?;
        auth_claims_from_token(token, &keys)
    }
}

/// Returns the credentials part of a `Bearer` authorization header value.
///
/// The auth-scheme is compared case-insensitively (RFC 7235 §2.1 makes schemes
/// case-insensitive), so `bearer x` and `BEARER x` are accepted as well as `Bearer x`.
/// Extra spaces between the scheme and the token are skipped. Returns `None` when the
/// value has no space-separated token part or uses another scheme.
fn bearer_token(header_value: &str) -> Option<&str> {
    let (scheme, token) = header_value.split_once(' ')?;
    if scheme.eq_ignore_ascii_case("Bearer") {
        Some(token.trim_start())
    } else {
        None
    }
}
/// Hashes `password` with `Argon2::default()` and a freshly generated random salt.
///
/// Returns the encoded hash string, which `verify_password` parses back with
/// `PasswordHash::new`.
fn hash_password(password: &str) -> anyhow::Result<String> {
    // A new salt per call means identical passwords never yield identical hashes.
    let salt = SaltString::generate(&mut OsRng);
    Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .map(|hash| hash.to_string())
        .map_err(|err| anyhow::anyhow!(err))
}
/// Checks `password` against an encoded Argon2 hash as produced by `hash_password`.
///
/// Returns `false` both when the password does not match and when `stored_hash`
/// cannot be parsed; callers cannot tell the two cases apart.
fn verify_password(password: &str, stored_hash: &str) -> bool {
    PasswordHash::new(stored_hash)
        .map(|parsed| {
            Argon2::default()
                .verify_password(password.as_bytes(), &parsed)
                .is_ok()
        })
        .unwrap_or(false)
}
/// JSON body accepted by `register_user` (through [`RegisterPayload`]).
#[derive(Debug, Deserialize)]
pub struct RegisterValues {
    /// Login name, stored in `users.username`.
    username: String,
    /// Plain-text password; only its Argon2 hash is persisted.
    password: String,
    /// Hotels to link the new user to; a missing field deserializes to an empty list.
    #[serde(default)]
    hotel_ids: Vec<i32>,
    /// Display name, stored in `users.displayname`.
    displayname: String,
}
/// Extractor wrapping [`RegisterValues`]; a malformed JSON body is rejected as
/// `400 Bad Request` with a plain-text `"Invalid body: ..."` message.
pub struct RegisterPayload(pub RegisterValues);

impl<S> FromRequest<S> for RegisterPayload
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request(req: ExtractRequest, state: &S) -> Result<Self, Self::Rejection> {
        match Json::<RegisterValues>::from_request(req, state).await {
            Ok(Json(values)) => Ok(Self(values)),
            Err(rejection) => Err((
                StatusCode::BAD_REQUEST,
                format!("Invalid body: {}", rejection),
            )),
        }
    }
}
//TODO: Batch query hotel names with IN (...)
/// Creates a user account and links it to the requested hotels.
///
/// The password is Argon2-hashed before storage. The `users` row and every
/// `hotel_user_link` row are written inside one transaction, so an unknown hotel id
/// (or any other failure) leaves no partially registered user behind.
///
/// NOTE(review): there is no authorization here — any caller can register a user
/// linked to any existing hotel id. Confirm this route is protected elsewhere.
///
/// # Errors
/// - `400` if a hotel id in the payload does not exist;
/// - `500` on hashing, connection, transaction, query or insert failures.
pub async fn register_user(
    State(state): State<AppState>,
    RegisterPayload(payload): RegisterPayload,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let hashed_password = hash_password(&payload.password).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Password hashing failed: {}", e),
        )
    })?;
    let mut conn = state.logs_pool.get().map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("DB connection error: {}", e),
        )
    })?;
    // Dropping `tx` without `commit()` rolls back, so every early `?` return below also
    // undoes the user insert.
    let tx = conn.transaction().map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("DB transaction error: {}", e),
        )
    })?;
    tx.execute(
        "INSERT INTO users (username, password, displayname)
         VALUES (?1, ?2, ?3)",
        params![payload.username, hashed_password, payload.displayname],
    )
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("User insert error: {}", e),
        )
    })?;
    let user_id = tx.last_insert_rowid();
    for &hotel_id in &payload.hotel_ids {
        //FIXME: needs to be the display name in the DB, scheme is currently wrong
        let hotel_name: String = tx
            .query_row(
                "SELECT hotelname FROM hotels WHERE id = ?1",
                params![hotel_id],
                |row| row.get(0),
            )
            .optional()
            .map_err(|e| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Hotel lookup error for hotel_id={}: {}", hotel_id, e),
                )
            })?
            // Only a genuinely missing hotel is the client's fault.
            .ok_or_else(|| {
                (
                    StatusCode::BAD_REQUEST,
                    format!("Invalid hotel id {}: no such hotel", hotel_id),
                )
            })?;
        tx.execute(
            "INSERT INTO hotel_user_link (user_id, hotel_id, username, hotelname)
             VALUES (?1, ?2, ?3, ?4)",
            params![user_id, hotel_id, payload.username, hotel_name],
        )
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!(
                    "Link insert error for user_id={} hotel_id={}: {}",
                    user_id, hotel_id, e
                ),
            )
        })?;
    }
    tx.commit().map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("DB commit error: {}", e),
        )
    })?;
    Ok((
        StatusCode::CREATED,
        "User registered successfully".to_string(),
    ))
}
/// JSON body accepted by `force_update_password`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ForceUpdatePasswordValues {
    /// Target user's login name.
    username: String,
    /// New plain-text password to hash and store.
    newpassword: String,
    /// Hotel the target user row must belong to (`users.hotel_id`).
    hotel_id: i32,
    /// Administrator secret authorizing the reset.
    admin_pass: String,
}
/// Administrator password reset.
///
/// Sets `newpassword` for the user matching `username` and `hotel_id` without
/// requiring that user's current password.
///
/// The request's `admin_pass` must equal the `ADMIN_PASSWORD` environment variable.
/// The secret is deliberately not kept in source, because anything committed to the
/// repository is readable by everyone with access to it.
///
/// Responses:
/// - `200` on success;
/// - `401` for a wrong admin password or an unknown user;
/// - `404` if the row vanished before the update;
/// - `500` if `ADMIN_PASSWORD` is unset or empty, or on DB/hashing errors.
pub async fn force_update_password(
    State(state): State<AppState>,
    Json(payload): Json<ForceUpdatePasswordValues>,
) -> impl IntoResponse {
    /// Best-effort constant-time byte comparison: its running time depends only on the
    /// lengths, not on the position of the first mismatching byte.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }

    // Authorize before touching the DB, so unauthenticated callers cannot learn which
    // usernames exist from the "Not correct user" response.
    let admin_secret = match std::env::var("ADMIN_PASSWORD") {
        Ok(secret) if !secret.is_empty() => secret,
        _ => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Admin password not configured",
            )
                .into_response();
        }
    };
    if !constant_time_eq(payload.admin_pass.as_bytes(), admin_secret.as_bytes()) {
        return (StatusCode::UNAUTHORIZED, "Invalid Amin Password").into_response();
    }
    let conn = match state.logs_pool.get() {
        Ok(c) => c,
        Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response(),
    };
    let user_row = match conn
        .query_row(
            "SELECT id FROM users WHERE username = ?1 AND hotel_id = ?2",
            params![&payload.username, &payload.hotel_id],
            |row| row.get::<_, i32>(0),
        )
        .optional()
    {
        Ok(opt) => opt,
        Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(),
    };
    let user_id = match user_row {
        Some(u) => u,
        None => return (StatusCode::UNAUTHORIZED, "Not correct user").into_response(),
    };
    let hashed_password = match hash_password(&payload.newpassword) {
        Ok(h) => h,
        Err(_) => {
            return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response();
        }
    };
    let result = conn.execute(
        "UPDATE users SET password = ?1 WHERE id = ?2",
        params![&hashed_password, &user_id],
    );
    match result {
        Ok(rows) if rows > 0 => (StatusCode::OK, "Password updated").into_response(),
        Ok(_) => (StatusCode::NOT_FOUND, "User not found").into_response(),
        Err(_) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            "Failed to update password",
        )
            .into_response(),
    }
}
/// JSON body accepted by `update_password`.
#[derive(Debug, Serialize, Deserialize)]
pub struct UpdatePasswordValues {
    /// Login name of the account being changed.
    username: String,
    /// Current plain-text password, verified against the stored hash.
    current_password: String,
    /// Replacement plain-text password to hash and store.
    newpassword: String,
}
/// Lets a user change their own password by proving knowledge of the current one.
///
/// Responses:
/// - `200` on success;
/// - `401` for an unknown username or a wrong current password;
/// - `404` if no row was updated;
/// - `500` on connection, query, hashing or update errors.
pub async fn update_password(
    State(state): State<AppState>,
    Json(payload): Json<UpdatePasswordValues>,
) -> impl IntoResponse {
    let Ok(conn) = state.logs_pool.get() else {
        return (StatusCode::INTERNAL_SERVER_ERROR, "DB, conn failed").into_response();
    };
    let lookup = conn
        .query_row(
            "SELECT password, id FROM users WHERE username = ?1",
            params![&payload.username],
            |row| Ok((row.get::<_, String>(0)?, row.get::<_, i32>(1)?)),
        )
        .optional();
    let (stored_hash, user_id) = match lookup {
        Err(e) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("DB query error: {}", e),
            )
                .into_response();
        }
        Ok(None) => return (StatusCode::UNAUTHORIZED, "Not correct user").into_response(),
        Ok(Some(found)) => found,
    };
    if !verify_password(&payload.current_password, &stored_hash) {
        return (StatusCode::UNAUTHORIZED, "Invalid Password").into_response();
    }
    let Ok(new_hash) = hash_password(&payload.newpassword) else {
        return (StatusCode::INTERNAL_SERVER_ERROR, "Password hashing failed").into_response();
    };
    match conn.execute(
        "UPDATE users SET password = ?1 WHERE id = ?2",
        params![&new_hash, &user_id],
    ) {
        Ok(0) => (StatusCode::NOT_FOUND, "User not found").into_response(),
        Ok(_) => (StatusCode::OK, "Password updated").into_response(),
        Err(_) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            "Failed to update password",
        )
            .into_response(),
    }
}
/// JSON body accepted by `clean_auth_loging` (through [`LoginPayload`]).
#[derive(Debug, Deserialize)]
pub struct LoginValues {
    /// Login name looked up in `users.username`.
    username: String,
    /// Plain-text password checked against the stored Argon2 hash.
    password: String,
}
/// Extractor wrapping [`LoginValues`]; a malformed JSON body is rejected as
/// `400 Bad Request` with a plain-text `"Invalid body: ..."` message.
pub struct LoginPayload(pub LoginValues);

impl<S> FromRequest<S> for LoginPayload
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, String);

    async fn from_request(req: ExtractRequest, state: &S) -> Result<Self, Self::Rejection> {
        Json::<LoginValues>::from_request(req, state)
            .await
            .map(|Json(values)| LoginPayload(values))
            .map_err(|err| (StatusCode::BAD_REQUEST, format!("Invalid body: {}", err)))
    }
}
/// Access-token JWT payload, as decoded by `auth_claims_from_token`.
///
/// The serialized field names (`id`, `hotel_id`, `exp`) are the JWT claim names.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct Claims {
    /// Id of the authenticated user (`users.id`).
    id: i32,
    /// Hotel this token was issued for.
    hotel_id: i32,
    /// Expiry as a Unix timestamp in seconds.
    exp: usize,
}
/// Response body of `clean_auth_loging`: a single access token.
#[derive(Serialize)]
struct LoginResponse {
    /// Signed access JWT.
    token: String,
}

/// Response body of `login_refresh_token`: one access token per linked hotel.
#[derive(Serialize)]
struct MultiLoginResponse {
    /// Id of the user the refresh token belongs to.
    user_id: i32,
    /// One signed access JWT per hotel id in the refresh token's hotel list.
    tokens: Vec<String>,
}
pub async fn clean_auth_loging(
State(state): State<AppState>,
Extension(keys): Extension<JwtKeys>,
LoginPayload(payload): LoginPayload,
) -> impl IntoResponse {
// 1⃣ Get a connection from logs pool
let conn = match state.logs_pool.get() {
Ok(c) => c,
Err(_) => {
return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
}
};
let user_row = match conn
.query_row(
"SELECT id, password, hotel_id, displayname FROM users WHERE username = ?1",
params![&payload.username],
|row| {
let user_id: i32 = row.get(0)?;
let password: String = row.get(1)?;
let hotel_id: i32 = row.get(2)?;
let displayname: String = row.get(3)?;
Ok((user_id, password, hotel_id, displayname))
},
)
.optional()
{
Ok(opt) => opt,
Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, "DB query error").into_response(),
};
let (user_id, stored_hash, hotel_id, _displayname) = match user_row {
Some(u) => u,
None => return (StatusCode::UNAUTHORIZED, "Invalid credentials").into_response(),
};
if !verify_password(&payload.password, &stored_hash) {
return (StatusCode::UNAUTHORIZED, "Invalid credentials").into_response();
}
let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) {
Some(time) => time.timestamp() as usize,
None => {
// Handle overflow — probably a 500, since this should never happen
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Time overflow".to_string(),
)
.into_response();
}
};
let claims = serde_json::json!({
"id": user_id,
"hotel_id": hotel_id,
//"username": payload.username,
"exp": expiration
});
let token = match encode(&Header::default(), &claims, &keys.encoding) {
Ok(t) => t,
Err(_) => {
return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response();
}
};
Json(LoginResponse { token }).into_response()
}
/// JSON body accepted by `create_refresh_token`.
#[derive(Debug, Deserialize)]
pub struct CreateRefreshTokenValue {
    /// Login name looked up in `users.username`.
    pub username: String,
    /// Plain-text password checked against the stored hash.
    pub password: String,
    /// Device the refresh token is bound to; one token is kept per `(user, device)`.
    pub device_id: Uuid,
}
/// Authenticates `username`/`password` and issues a refresh token for `device_id`.
///
/// The token is delivered as an `HttpOnly; Secure` `refresh_token` cookie. The user's
/// current `hotel_user_link` hotel ids are snapshotted into
/// `refresh_token.hotel_id_list` as a JSON array. A repeated call for the same
/// `(user_id, device_id)` replaces that device's previous token (`ON CONFLICT ... DO
/// UPDATE`).
///
/// # Errors
/// - `401` for an unknown username or a wrong password;
/// - `500` for database failures.
#[axum::debug_handler]
pub async fn create_refresh_token(
    State(state): State<AppState>,
    //TODO: remove user agent entirely from auth, it is mutable and not stable
    _user_agent: Option<TypedHeader<UserAgent>>,
    Json(payload): Json<CreateRefreshTokenValue>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let device_id_str = payload.device_id.to_string();
    let conn = state.logs_pool.get().map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            "DB connection error".to_string(),
        )
    })?;

    let credentials = conn
        .query_row(
            "SELECT id, password FROM users WHERE username = ?1",
            params![&payload.username],
            |row| {
                let user_id: i32 = row.get(0)?;
                let password: String = row.get(1)?;
                Ok((user_id, password))
            },
        )
        .optional()
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("error fetching credentials: {e}"),
            )
        })?;
    // Unknown user and wrong password share one 401 answer, so the endpoint does not
    // reveal which usernames exist.
    let user_id = match credentials {
        Some((user_id, stored_hash)) if verify_password(&payload.password, &stored_hash) => {
            user_id
        }
        _ => {
            return Err((StatusCode::UNAUTHORIZED, "Invalid credential".to_string()));
        }
    };

    //TODO: get hotel name to return a map/tuple of hotel name
    let mut stmt = conn
        .prepare("SELECT hotel_id FROM hotel_user_link WHERE user_id = ?1")
        .map_err(|_| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "error building user_id fetch stmt".to_string(),
            )
        })?;
    let hotel_ids: Vec<i32> = stmt
        .query_map(params![&user_id], |row| row.get(0))
        .map_err(|_| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Error mapping hotel_ids".to_string(),
            )
        })?
        .collect::<Result<Vec<_>, _>>()
        .map_err(|_| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Error collecting hotel_ids".to_string(),
            )
        })?;
    let hotel_ids_json = serde_json::to_string(&hotel_ids).map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            "Error mapping hotel_ids".to_string(),
        )
    })?;

    // The refresh token is a bearer secret: it must never be printed or logged.
    let raw_token = Uuid::new_v4().to_string();
    // FIXME: the raw token is stored in the `token_hash` column and compared verbatim by
    // `login_refresh_token` and `logout_from_single_device`; introducing a deterministic
    // hash (e.g. SHA-256) has to change all three places together.
    //TODO: make the token refresh on login
    conn.execute(
        r#"
        INSERT INTO refresh_token (
            user_id,
            token_hash,
            device_id,
            hotel_id_list
        )
        VALUES (?1, ?2, ?3, ?4)
        ON CONFLICT(user_id, device_id)
        DO UPDATE SET
            token_hash = excluded.token_hash,
            hotel_id_list = excluded.hotel_id_list
        "#,
        params![user_id, raw_token, device_id_str, hotel_ids_json],
    )
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("DB error: {}", e),
        )
    })?;

    // NOTE(review): Max-Age=60480000000 s is roughly 1,900 years (browsers may clamp it);
    // 604800 (one week) may have been intended. Confirm before changing, since it alters
    // session lifetime.
    let cookie_value = format!(
        "refresh_token={}; HttpOnly; Secure; Max-Age=60480000000;Path=/",
        raw_token
    );
    let cookie = HeaderValue::from_str(&cookie_value).map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            "Invalid refresh token cookie".to_string(),
        )
    })?;
    let mut response = (StatusCode::CREATED, "Refresh token created successfully").into_response();
    response.headers_mut().insert(SET_COOKIE, cookie);
    Ok(response)
}
/// JSON body of `login_refresh_token` and `logout_from_single_device`.
#[derive(Deserialize)]
pub struct LoginRefreshTokenValues {
    /// Device the refresh token was issued for (see `CreateRefreshTokenValue::device_id`).
    device_id: Uuid,
}
//TODO: LATER : implement hotel-id-selected to allow user to only get part hotels ?
pub async fn login_refresh_token(
State(state): State<AppState>,
Extension(keys): Extension<JwtKeys>,
user_agent: Option<TypedHeader<UserAgent>>,
cookie_header: Option<TypedHeader<headers::Cookie>>,
Json(payload): Json<LoginRefreshTokenValues>,
) -> impl IntoResponse {
println!("login_refresh_token called");
// Log cookies
let cookies = match cookie_header {
Some(token) => token,
None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(),
};
let refresh_token = match cookies.get("refresh_token") {
Some(token) => token.to_string(),
None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(),
};
println!("RAW refresh_token bytes: {:?}", refresh_token.as_bytes());
println!("RAW refresh_token : {}", refresh_token.to_string());
println!("RAW refresh_token len: {}", refresh_token.len());
println!("Cookies: {:?}", &refresh_token);
let conn = match state.logs_pool.get() {
Ok(c) => c,
Err(_) => {
return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
}
};
let user_agent_str = match user_agent {
Some(ua) => ua.to_string(),
None => return (StatusCode::INTERNAL_SERVER_ERROR, "user agent unknown").into_response(),
};
println!("UA {:?}", &user_agent_str);
let device_id_str = payload.device_id.to_string();
println!("device id: {:?}", &device_id_str);
//"SELECT user_id, token_hash, hotel_id FROM refresh_token WHERE device_id = ?1 AND user_agent = ?2",
//TODO: swap to query row and get hotel-id's list and not single hotel per row
//deserializing the list :
//let hotel_ids: Vec<i32> = serde_json::from_str(&stored_value)?;
let mut stmt = match conn.prepare(
"SELECT user_id, hotel_id_list
FROM refresh_token
WHERE device_id = ?1 AND token_hash = ?2
LIMIT 1;",
) {
Ok(s) => s,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Error prepatring hotel_id_list stmt",
)
.into_response();
}
};
let rows = match stmt
.query_one(params![&device_id_str, &refresh_token], |row| {
Ok((
row.get::<_, i32>(0)?, // user_id
row.get::<_, String>(1)?, // token_hash
//row.get::<_, String>(2)?, // hotel_id //FIXME: this is supposed to be vectore maybe ?
))
})
.optional()
{
Ok(r) => r,
Err(e) => {
eprintln!("DB ERROR: {:?}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("DB query error: {}", e),
)
.into_response();
}
};
//TODO: extraction of the blob
//let json_hotel_ids = rows.2;
let (user_id, json_hotel_ids) = match rows {
Some(r) => r,
None => {
return (
StatusCode::UNAUTHORIZED,
"No refresh token found for this device",
)
.into_response();
}
};
let hotel_ids: Vec<i32> = match serde_json::from_str(&json_hotel_ids) {
Ok(ids) => ids,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Hotel ids are not deserializable to Vec",
)
.into_response();
}
};
//FIXME: still problems when corrupted token exist
if hotel_ids.is_empty() {
return (StatusCode::UNAUTHORIZED, "No matching device").into_response();
}
/*
eprintln!("DB ERROR: {:?}", &refresh_token);
eprintln!("DB ERROR: {:?}", &token);
//still not auto adding hotel user link when creating account
if (&refresh_token != &token) {
// skip rows with wrong hash
return (StatusCode::UNAUTHORIZED, "Invelid credentials").into_response();
}
*/
let expiration = match chrono::Utc::now().checked_add_signed(chrono::Duration::hours(15)) {
Some(time) => time.timestamp() as usize,
None => {
// Handle overflow — probably a 500, since this should never happen
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Time overflow".to_string(),
)
.into_response();
}
};
let mut tokens = Vec::new();
for hotel_id in hotel_ids {
let claims = serde_json::json!({
"id": user_id,
"hotel_id": hotel_id,
"exp": expiration
});
let token = match encode(&Header::default(), &claims, &keys.encoding) {
Ok(token) => token,
Err(_) => {
return (StatusCode::INTERNAL_SERVER_ERROR, "JWT creation failed").into_response();
}
};
tokens.push(token);
}
if tokens.is_empty() {
return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response();
}
//Json(tokens).into_response()
Json(MultiLoginResponse { user_id, tokens }).into_response()
}
#[axum::debug_handler]
pub async fn logout_from_single_device(
State(state): State<AppState>,
Extension(keys): Extension<JwtKeys>,
user_agent: Option<TypedHeader<UserAgent>>,
cookie_header: Option<TypedHeader<headers::Cookie>>,
Json(payload): Json<LoginRefreshTokenValues>,
) -> impl IntoResponse {
let user_agent_str = user_agent
.map(|TypedHeader(ua)| ua.as_str().to_owned())
.unwrap_or_else(|| "Unknown".to_string());
let cookies = match cookie_header {
Some(token) => token,
None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(),
};
let refresh_token = match cookies.get("refresh_token") {
Some(token) => token.to_string(),
None => return (StatusCode::UNAUTHORIZED, "Missing refresh token cookie").into_response(),
};
let device_id_str = payload.device_id.to_string();
let conn = match state.logs_pool.get() {
Ok(c) => c,
Err(_) => {
return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
}
};
let device_row = match conn
.query_row(
"SELECT user_id, hotel_id_list, id
FROM refresh_token
WHERE token_hash = ?1 AND revoked = 0 ",
params![&refresh_token],
|row| {
let user_id: i32 = row.get(0)?;
let json_hotel_id_list: String = row.get(1)?;
let id: i32 = row.get(2)?;
//let displayname: String = row.get(3)?;
Ok((user_id, json_hotel_id_list, id))
},
)
.optional()
{
Ok(opt) => opt,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("DB query error : {}", e),
)
.into_response();
}
};
let (user_id, json_hotel_id_list, token_id) = match device_row {
Some(tuple) => tuple,
None => return (StatusCode::UNAUTHORIZED, "No matching device").into_response(),
};
let hotel_ids: Vec<i32> = match serde_json::from_str(&json_hotel_id_list) {
Ok(ids) => ids,
Err(_) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Hotel ids are not deserializable to Vec",
)
.into_response();
}
};
//FIXME: need to chang the way we get refresh token from the cookies instead
/*
if !verify_password(&payload.refresh_token, &token_hash) {
return (StatusCode::UNAUTHORIZED, "Invalid or mismatched token").into_response();
}
*/
let revoked: Result<String, rusqlite::Error> = conn.query_row(
"DELETE FROM refresh_token
WHERE id = ?1
RETURNING device_id",
params![&token_id],
|row| row.get(0),
);
let revoked_id = match (revoked) {
Ok(r) => r,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Hotel ids are not deserializable to Vec",
)
.into_response();
}
};
let cookie_value = format!(
"refresh_token={}; HttpOnly; Secure; Max-Age=0;Path=/",
"loggedout"
);
let mut response = (
StatusCode::CREATED,
format!("Token deleted for device id {}", &revoked_id),
)
.into_response();
response
.headers_mut()
.insert(SET_COOKIE, HeaderValue::from_str(&cookie_value).unwrap());
response
}
/// Revokes every refresh token of the authenticated user (all devices) and clears the
/// caller's `refresh_token` cookie.
///
/// The user is identified by the `Authorization: Bearer` access token through
/// [`AuthClaims`]. Responds `200` with the number of deleted tokens, or `500` with the
/// DB error text; in the error case the cookie is left untouched.
pub async fn logout_from_all_devices(
    State(state): State<AppState>,
    Extension(_keys): Extension<JwtKeys>,
    AuthClaims { user_id, .. }: AuthClaims,
) -> impl IntoResponse {
    let conn = match state.logs_pool.get() {
        Ok(c) => c,
        Err(_) => {
            return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
        }
    };
    let revoked = match conn.execute(
        "DELETE FROM refresh_token WHERE user_id = ?1",
        params![&user_id],
    ) {
        Ok(count) => count,
        Err(err) => return (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
    };
    // Max-Age=0 with the same Path tells the browser to discard the cookie.
    let mut response = (
        StatusCode::OK,
        format!("Revoked {} refresh token(s)", revoked),
    )
        .into_response();
    response.headers_mut().insert(
        SET_COOKIE,
        HeaderValue::from_static("refresh_token=loggedout; HttpOnly; Secure; Max-Age=0;Path=/"),
    );
    response
}
/// One entry of the `get_hotel` listing, built from a `hotels` row.
#[derive(Serialize)]
struct HotelData {
    /// `hotels.id`.
    id: i32,
    /// `hotels.hotelname`, serialized under the key `hotel_name`.
    hotel_name: String,
}
pub async fn get_hotel(State(state): State<AppState>) -> impl IntoResponse {
let try_conn = state.logs_pool.get();
let conn = match try_conn {
Ok(conn) => conn,
Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, "bruh").into_response(),
};
let try_stmt = conn.prepare(
"
SELECT id, hotelname
FROM hotels",
);
let mut stmt = match try_stmt {
Ok(stmt) => stmt,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"failed buildin statement",
)
.into_response();
}
};
let try_hotels = stmt.query_map(params![], |row| {
Ok(HotelData {
id: row.get(0)?,
hotel_name: row.get(1)?,
})
});
let hotel_itter = match try_hotels {
Ok(hotels) => hotels,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
"error processing hotel list",
)
.into_response();
}
};
let hotels: Vec<HotelData> = match hotel_itter.collect::<Result<Vec<_>, _>>() {
Ok(hotel) => hotel,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("failed collection of hotel : {e}"),
)
.into_response();
}
};
match serde_json::to_string(&hotels) {
Ok(json) => return (StatusCode::OK, json).into_response(),
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Serialization failed: {}", e),
)
.into_response();
}
};
//.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error".to_string()))?;
//return (StatusCode::OK).into_response();
}
#[derive(Deserialize, Debug)]
pub struct addHotelUser {
user_id: i32,
#[serde(default)]
hotel_ids: Vec<i32>,
}
/// Links an existing user to additional hotels by inserting `hotel_user_link` rows.
///
/// All links are written in one transaction. If any hotel id is unknown or any insert
/// fails, none of the links from this request are kept.
///
/// NOTE(review): no caller authentication or authorization happens here (the `JwtKeys`
/// extension is extracted but unused). As written, any client can grant any user
/// access to any hotel. Confirm this route is protected elsewhere.
///
/// # Errors
/// - `404` if `user_id` does not exist;
/// - `400` if one of `hotel_ids` does not exist;
/// - `500` on connection, transaction, lookup, insert or commit failures.
pub async fn add_hotel_user(
    State(state): State<AppState>,
    Extension(_keys): Extension<JwtKeys>,
    Json(payload): Json<addHotelUser>,
) -> impl IntoResponse {
    let mut conn = match state.logs_pool.get() {
        Ok(c) => c,
        Err(_) => {
            return (StatusCode::INTERNAL_SERVER_ERROR, "DB connection error").into_response();
        }
    };
    // Any early return drops `tx` without commit, which rolls back the links inserted
    // so far.
    let tx = match conn.transaction() {
        Ok(tx) => tx,
        Err(e) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("could't start transaction : {e} "),
            )
                .into_response();
        }
    };
    let user_name: String = match tx
        .query_row(
            "SELECT username FROM users WHERE id = ?1",
            params![&payload.user_id],
            |row| row.get(0),
        )
        .optional()
    {
        Ok(Some(name)) => name,
        Ok(None) => {
            return (
                StatusCode::NOT_FOUND,
                format!("user not found {} ", payload.user_id),
            )
                .into_response();
        }
        Err(e) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("user lookup failed {e} "),
            )
                .into_response();
        }
    };
    // The prepared statements borrow `tx`; scoping them ends the borrow before `commit`.
    {
        let mut get_hotel_name_stmt =
            match tx.prepare("SELECT hotelname FROM hotels WHERE id = ?1") {
                Ok(stmt) => stmt,
                Err(e) => {
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("could't prepare stmt for hotel : {e} "),
                    )
                        .into_response();
                }
            };
        let mut insert_hotel_link_stmt = match tx.prepare(
            "INSERT INTO hotel_user_link
             (user_id,hotel_id,username,hotelname)
             VALUES (?1,?2,?3,?4)",
        ) {
            Ok(stmt) => stmt,
            Err(e) => {
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("could't prepare stmt to insert hotel : {e} "),
                )
                    .into_response();
            }
        };
        for &hotel_id in &payload.hotel_ids {
            let hotel_name: String = match get_hotel_name_stmt
                .query_row(params![hotel_id], |row| row.get(0))
                .optional()
            {
                Ok(Some(name)) => name,
                Ok(None) => {
                    return (
                        StatusCode::BAD_REQUEST,
                        format!("hotel not found {hotel_id} "),
                    )
                        .into_response();
                }
                Err(e) => {
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("hotel lookup failed {e} "),
                    )
                        .into_response();
                }
            };
            if let Err(e) = insert_hotel_link_stmt.execute(params![
                payload.user_id,
                hotel_id,
                user_name,
                hotel_name
            ]) {
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("could't insert hotel link for hotel {hotel_id} : {e} "),
                )
                    .into_response();
            }
        }
    }
    if let Err(e) = tx.commit() {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("could't commit hotel links : {e} "),
        )
            .into_response();
    }
    (StatusCode::OK, "Hotel links added").into_response()
}
/// Maps any displayable error to a `500 Internal Server Error` rejection with the body
/// `"Internal error: <err>"`.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
    E: std::fmt::Display,
{
    let body = format!("Internal error: {err}");
    (StatusCode::INTERNAL_SERVER_ERROR, body)
}