From a76200bb7adb6189d84e0e98d6233a470ebeee98 Mon Sep 17 00:00:00 2001 From: Santo Cariotti Date: Wed, 21 Aug 2024 17:34:58 +0200 Subject: Authentication for endpoints --- src/errors.rs | 5 +---- src/graphql/query.rs | 2 +- src/graphql/routes.rs | 9 +++++--- src/graphql/types/jwt.rs | 55 +++++++++++++++++++++++++++++++++++++++++------ src/graphql/types/user.rs | 41 ++++++++++++++++++++--------------- src/main.rs | 4 +++- src/state.rs | 1 + 7 files changed, 84 insertions(+), 33 deletions(-) (limited to 'src') diff --git a/src/errors.rs b/src/errors.rs index fafe0b0..1b9a802 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -42,10 +42,7 @@ impl IntoResponse for AppError { "Token creation error".to_string(), ), AppError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token".to_string()), - AppError::Unauthorized => ( - StatusCode::UNAUTHORIZED, - "Can't perform this action".to_string(), - ), + AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()), }; let body = Json(json!({ diff --git a/src/graphql/query.rs b/src/graphql/query.rs index ed83a9f..254dab6 100644 --- a/src/graphql/query.rs +++ b/src/graphql/query.rs @@ -1,4 +1,4 @@ -use crate::graphql::types::user; +use crate::{errors::AppError, graphql::types::user}; use async_graphql::{Context, Object}; pub struct Query; diff --git a/src/graphql/routes.rs b/src/graphql/routes.rs index e15267e..a566c65 100644 --- a/src/graphql/routes.rs +++ b/src/graphql/routes.rs @@ -2,11 +2,14 @@ use crate::graphql::mutation::Mutation; use crate::graphql::query::Query; use async_graphql::{EmptySubscription, Schema}; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; -use std::sync::Arc; +use axum::extract::Extension; + +use super::types::jwt::Authentication; pub async fn graphql_handler( - schema: Arc>, + schema: Extension>, + auth: Authentication, req: GraphQLRequest, ) -> GraphQLResponse { - schema.execute(req.into_inner()).await.into() + schema.execute(req.0.data(auth)).await.into() } diff --git a/src/graphql/types/jwt.rs b/src/graphql/types/jwt.rs index 932f7fd..c118622 100644 --- a/src/graphql/types/jwt.rs +++ b/src/graphql/types/jwt.rs @@ -1,5 +1,10 @@ use crate::errors::AppError; use async_graphql::{InputObject, SimpleObject}; +use axum::{async_trait, extract::FromRequestParts, http::request::Parts}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization}, + typed_header::TypedHeader, +}; use chrono::{Duration, Local}; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use once_cell::sync::Lazy; @@ -7,7 +12,7 @@ use serde::{Deserialize, Serialize}; struct Keys { encoding: EncodingKey, - _decoding: DecodingKey, + decoding: DecodingKey, } static KEYS: Lazy = Lazy::new(|| { @@ -19,20 +24,25 @@ impl Keys { fn new(secret: &[u8]) -> Self { Self { encoding: EncodingKey::from_secret(secret), - _decoding: DecodingKey::from_secret(secret), + decoding: DecodingKey::from_secret(secret), } } } -/// Claims struct -#[derive(Serialize, Deserialize)] +/// Claims struct. +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct Claims { - /// ID from the user model - pub user_id: i32, - /// Expiration timestamp + user_id: i32, exp: usize, } +/// Authentication enum +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum Authentication { + Logged(Claims), + NotLogged, +} + impl Claims { /// Create a new Claim using the `user_id` and the current timestamp + 2 days pub fn new(user_id: i32) -> Self { @@ -77,3 +87,34 @@ impl AuthBody { } } } + +#[async_trait] +impl FromRequestParts for Authentication +where + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + // Extract the Authorization header + match TypedHeader::>::from_request_parts(parts, &()).await { + Ok(TypedHeader(Authorization(bearer))) => { + // Decode the token + let token_data = + decode::(bearer.token(), &KEYS.decoding, &Validation::default()) + .map_err(|err| match err.kind() { + jsonwebtoken::errors::ErrorKind::ExpiredSignature => { + AppError::InvalidToken + } + _ => { + eprintln!("{err:?}"); + return AppError::Unauthorized; + } + })?; + + Ok(Self::Logged(token_data.claims)) + } + Err(_) => Ok(Self::NotLogged), + } + } +} diff --git a/src/graphql/types/user.rs b/src/graphql/types/user.rs index bf9080f..b675f1f 100644 --- a/src/graphql/types/user.rs +++ b/src/graphql/types/user.rs @@ -2,6 +2,8 @@ use crate::state::AppState; use async_graphql::{Context, Object}; use serde::{Deserialize, Serialize}; +use super::jwt::Authentication; + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct User { pub id: i32, @@ -32,21 +34,26 @@ impl User { pub async fn get_users<'ctx>(ctx: &Context<'ctx>) -> Result>, String> { let state = ctx.data::().expect("Can't connect to db"); let client = &*state.client; - - let rows = client - .query("SELECT id, email, password, is_admin FROM users", &[]) - .await - .unwrap(); - - let users: Vec = rows - .iter() - .map(|row| User { - id: row.get("id"), - email: row.get("email"), - password: row.get("password"), - is_admin: row.get("is_admin"), - }) - .collect(); - - Ok(Some(users)) + let auth: &Authentication = ctx.data().unwrap(); + match auth { + Authentication::NotLogged => Err("Unauthorized".to_string()), + Authentication::Logged(_claims) => { + let rows = client + .query("SELECT id, email, password, is_admin FROM users", &[]) + .await + .unwrap(); + + let users: Vec = rows + .iter() + .map(|row| User { + id: row.get("id"), + email: row.get("email"), + password: row.get("password"), + is_admin: row.get("is_admin"), + }) + .collect(); + + Ok(Some(users)) + } + } } diff --git a/src/main.rs b/src/main.rs index 04e2564..8dfc145 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,6 +26,7 @@ use tracing::Span; async fn create_app() -> Router { logger::setup(); let dbclient = db::setup().await.unwrap(); + let state = state::AppState { client: Arc::new(dbclient), }; @@ -37,10 +38,11 @@ async fn create_app() -> Router { ) .data(state.clone()) .finish(); + Router::new() .route( "/graphql", - post(move |req| graphql::routes::graphql_handler(schema.clone().into(), req)), + post(graphql::routes::graphql_handler).layer(Extension(schema.clone())), ) .fallback(crate::routes::page_404) // Mark the `Authorization` request header as sensitive so it doesn't diff --git a/src/state.rs b/src/state.rs index d4719b9..876a5f4 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use tokio_postgres::Client; + #[derive(Clone)] pub struct AppState { pub client: Arc, -- cgit v1.2.3-18-g5258