diff options
| author | Santo Cariotti <santo@dcariotti.me> | 2022-09-01 16:45:04 +0000 |
|---|---|---|
| committer | Santo Cariotti <santo@dcariotti.me> | 2022-09-01 16:45:04 +0000 |
| commit | ab23761e090b8ab6311a360eada7131f6663a3bf (patch) | |
| tree | b5a99bb4cfc811e45fc2e3680b4f8b1e944515eb /src | |
Fork from m6-ie project
Diffstat (limited to 'src')
| -rw-r--r-- | src/db.rs | 29 | ||||
| -rw-r--r-- | src/errors.rs | 63 | ||||
| -rw-r--r-- | src/logger.rs | 12 | ||||
| -rw-r--r-- | src/main.rs | 60 | ||||
| -rw-r--r-- | src/models/auth.rs | 99 | ||||
| -rw-r--r-- | src/models/mod.rs | 2 | ||||
| -rw-r--r-- | src/models/user.rs | 118 | ||||
| -rw-r--r-- | src/routes/auth.rs | 25 | ||||
| -rw-r--r-- | src/routes/mod.rs | 2 | ||||
| -rw-r--r-- | src/routes/user.rs | 39 |
10 files changed, 449 insertions, 0 deletions
diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..43c3bd9 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,29 @@ +use crate::errors::AppError; + +use sqlx::postgres::PgPool; + +/// Static variable used to manage the database connection. Called with value = None raises a panic +/// error. +static mut CONNECTION: Option<PgPool> = None; + +/// Setup database connection. Get variable `DATABASE_URL` from the environment. Sqlx crate already +/// defines an error for environments without DATABASE_URL. +pub async fn setup() -> Result<(), AppError> { + let database_url = + std::env::var("DATABASE_URL").expect("Define `DATABASE_URL` environment variable."); + + unsafe { + CONNECTION = Some(PgPool::connect(&database_url).await?); + } + + Ok(()) +} + +/// Get connection. Raises an error if `setup()` has not been called yet. +/// Managing static `CONNECTION` is an unsafe operation. +pub unsafe fn get_client() -> &'static PgPool { + match &CONNECTION { + Some(client) => client, + None => panic!("Connection not established!"), + } +} diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..e541eda --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,63 @@ +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde_json::json; + +/// All errors raised by the web app +pub enum AppError { + /// Generic error, never called yet + Generic, + /// Database error + Database, + /// Generic bad request. It is handled with a message value + BadRequest(String), + /// Not found error + NotFound, + /// Raised when a token is not good created + TokenCreation, + /// Raised when a passed token is not valid + InvalidToken, +} + +/// Use `AppError` as response for an endpoint +impl IntoResponse for AppError { + /// Matches `AppError` into a tuple of status and error message. + /// The response will be a JSON in the format of: + /// ```json + /// { "error": "<message>" } + /// ``` + fn into_response(self) -> Response { + let (status, error_message) = match self { + AppError::Generic => ( + StatusCode::INTERNAL_SERVER_ERROR, + "Generic error, can't find why".to_string(), + ), + AppError::Database => ( + StatusCode::INTERNAL_SERVER_ERROR, + "Error with database connection".to_string(), + ), + AppError::BadRequest(value) => (StatusCode::BAD_REQUEST, value), + AppError::NotFound => (StatusCode::NOT_FOUND, "Element not found".to_string()), + AppError::TokenCreation => ( + StatusCode::INTERNAL_SERVER_ERROR, + "Token creation error".to_string(), + ), + AppError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token".to_string()), + }; + + let body = Json(json!({ + "error": error_message, + })); + + (status, body).into_response() + } +} + +/// Transforms a `sqlx::Error` into a `AppError::Databse` error +impl From<sqlx::Error> for AppError { + fn from(_error: sqlx::Error) -> AppError { + AppError::Database + } +} diff --git a/src/logger.rs b/src/logger.rs new file mode 100644 index 0000000..718384a --- /dev/null +++ b/src/logger.rs @@ -0,0 +1,12 @@ +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +/// Setup tracing subscriber logger +pub fn setup() { + tracing_subscriber::registry() + .with(tracing_subscriber::EnvFilter::new( + std::env::var("RUST_LOG") + .unwrap_or_else(|_| "m6_ie_2022=debug,tower_http=debug".into()), + )) + .with(tracing_subscriber::fmt::layer()) + .init(); +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..508d6cd --- /dev/null +++ b/src/main.rs @@ -0,0 +1,60 @@ +mod db; +mod errors; +mod logger; +mod models; +mod routes; + +use axum::{ + http::{header, Request}, + Router, +}; +use std::time::Duration; +use tower_http::sensitive_headers::SetSensitiveHeadersLayer; +use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer}; +use tracing::Span; + +/// Main application, called by the execution of the software +#[tokio::main] +async fn main() { + let app = create_app().await; + + /// By default the server is bind at "127.0.0.1:3000" + let addr = std::env::var("ALLOWED_HOST").unwrap_or_else(|_| "127.0.0.1:3000".to_string()); + tracing::info!("Listening on {}", addr); + + axum::Server::bind(&addr.parse().unwrap()) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +/// Create the app: setup everything and returns a `Router` +async fn create_app() -> Router { + logger::setup(); + let _ = db::setup().await; + + let api_routes = Router::new() + .nest("/users", routes::user::create_route()) + .nest("/auth", routes::auth::create_route()); + + Router::new() + // Map all routes to `/v1/*` namespace + .nest("/v1", api_routes) + // Mark the `Authorization` request header as sensitive so it doesn't + // show in logs. + .layer(SetSensitiveHeadersLayer::new(std::iter::once( + header::AUTHORIZATION, + ))) + // Use a layer for `TraceLayer` + .layer( + TraceLayer::new_for_http() + .on_request(|request: &Request<_>, _span: &Span| { + tracing::info!("{} {}", request.method(), request.uri()); + }) + .on_failure( + |error: ServerErrorsFailureClass, latency: Duration, _span: &Span| { + tracing::error!("{} | {} s", error, latency.as_secs()); + }, + ), + ) +} diff --git a/src/models/auth.rs b/src/models/auth.rs new file mode 100644 index 0000000..8b8f61c --- /dev/null +++ b/src/models/auth.rs @@ -0,0 +1,99 @@ +use crate::errors::AppError; +use axum::{ + async_trait, + extract::{FromRequest, RequestParts, TypedHeader}, + headers::{authorization::Bearer, Authorization}, +}; +use chrono::{Duration, Local}; +use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; +use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; + +struct Keys { + encoding: EncodingKey, + decoding: DecodingKey, +} + +/// Claims struct +#[derive(Serialize, Deserialize)] +pub struct Claims { + /// ID from the user model + user_id: i32, + /// Expiration timestamp + exp: usize, +} + +/// Body used as response to login +#[derive(Serialize)] +pub struct AuthBody { + /// Access token string + access_token: String, + /// "Bearer" string + token_type: String, +} + +static KEYS: Lazy<Keys> = Lazy::new(|| { + let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); + Keys::new(secret.as_bytes()) +}); + +impl Keys { + fn new(secret: &[u8]) -> Self { + Self { + encoding: EncodingKey::from_secret(secret), + decoding: DecodingKey::from_secret(secret), + } + } +} + +impl Claims { + /// Create a new Claim using the `user_id` and the current timestamp + 2 days + pub fn new(user_id: i32) -> Self { + let expiration = Local::now() + Duration::days(2); + + Self { + user_id, + exp: expiration.timestamp() as usize, + } + } + + /// Returns the token as a string. If a token is not encoded, raises an + /// `AppError::TokenCreation` + pub fn get_token(&self) -> Result<String, AppError> { + let token = encode(&Header::default(), &self, &KEYS.encoding) + .map_err(|_| AppError::TokenCreation)?; + + Ok(token) + } +} + +impl AuthBody { + pub fn new(access_token: String) -> Self { + Self { + access_token, + token_type: "Bearer".to_string(), + } + } +} + +/// Parse a request to get the Authorization header and then decode it checking its validation +#[async_trait] +impl<B> FromRequest<B> for Claims +where + B: Send, +{ + type Rejection = AppError; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + // Extract the token from the authorization header + let TypedHeader(Authorization(bearer)) = + TypedHeader::<Authorization<Bearer>>::from_request(req) + .await + .map_err(|_| AppError::InvalidToken)?; + // Decode the user data + let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default()) + .map_err(|_| AppError::InvalidToken)?; + + Ok(token_data.claims) + } +} diff --git a/src/models/mod.rs b/src/models/mod.rs new file mode 100644 index 0000000..f9bae3d --- /dev/null +++ b/src/models/mod.rs @@ -0,0 +1,2 @@ +pub mod auth; +pub mod user; diff --git a/src/models/user.rs b/src/models/user.rs new file mode 100644 index 0000000..06cde0a --- /dev/null +++ b/src/models/user.rs @@ -0,0 +1,118 @@ +use crate::db::get_client; +use crate::errors::AppError; + +use serde::{Deserialize, Serialize}; +use validator::Validate; + +/// User model +#[derive(Deserialize, Serialize, Validate)] +pub struct User { + id: i32, + #[validate(length(min = 1, message = "Can not be empty"))] + email: String, + #[validate(length(min = 8, message = "Must be min 8 chars length"))] + password: String, + is_staff: Option<bool>, +} + +/// Response used to print a user (or a users list) +#[derive(Deserialize, Serialize)] +pub struct UserList { + // It is public because it used by `Claims` creation + pub id: i32, + email: String, + is_staff: Option<bool>, +} + +/// Payload used for user creation +#[derive(Deserialize)] +pub struct UserCreate { + pub email: String, + pub password: String, +} + +impl User { + /// By default an user has id = 0. It is not created yet + pub fn new(email: String, password: String) -> Self { + Self { + id: 0, + email, + password, + is_staff: Some(false), + } + } + + /// Create a new user from the model using a SHA256 crypted password + pub async fn create(user: User) -> Result<UserList, AppError> { + let pool = unsafe { get_client() }; + + user.validate() + .map_err(|error| AppError::BadRequest(error.to_string()))?; + + let crypted_password = sha256::digest(user.password); + + let rec = sqlx::query_as!( + UserList, + r#" + INSERT INTO users (email, password) + VALUES ( $1, $2 ) + RETURNING id, email, is_staff + "#, + user.email, + crypted_password + ) + .fetch_one(pool) + .await?; + + Ok(rec) + } + + /// Find a user using the model. It used for login + pub async fn find(user: User) -> Result<UserList, AppError> { + let pool = unsafe { get_client() }; + + let crypted_password = sha256::digest(user.password); + + let rec = sqlx::query_as!( + UserList, + r#" + SELECT id, email, is_staff FROM "users" + WHERE email = $1 AND password = $2 + "#, + user.email, + crypted_password + ) + .fetch_one(pool) + .await?; + + Ok(rec) + } + + /// Returns the user with id = `user_id` + pub async fn find_by_id(user_id: i32) -> Result<UserList, AppError> { + let pool = unsafe { get_client() }; + + let rec = sqlx::query_as!( + UserList, + r#" + SELECT id, email, is_staff FROM "users" + WHERE id = $1 + "#, + user_id + ) + .fetch_one(pool) + .await?; + + Ok(rec) + } + + /// List all users + pub async fn list() -> Result<Vec<UserList>, AppError> { + let pool = unsafe { get_client() }; + let rows = sqlx::query_as!(UserList, r#"SELECT id, email, is_staff FROM users"#) + .fetch_all(pool) + .await?; + + Ok(rows) + } +} diff --git a/src/routes/auth.rs b/src/routes/auth.rs new file mode 100644 index 0000000..37c41b2 --- /dev/null +++ b/src/routes/auth.rs @@ -0,0 +1,25 @@ +use crate::errors::AppError; +use crate::models::{ + auth::{AuthBody, Claims}, + user::{User, UserCreate}, +}; +use axum::{routing::post, Json, Router}; + +/// Create routes for `/v1/auth/` namespace +pub fn create_route() -> Router { + Router::new().route("/login", post(make_login)) +} + +/// Make login. Check if a user with the email and password passed in request body exists into the +/// database +async fn make_login(Json(payload): Json<UserCreate>) -> Result<Json<AuthBody>, AppError> { + let user = User::new(payload.email, payload.password); + match User::find(user).await { + Ok(user) => { + let claims = Claims::new(user.id); + let token = claims.get_token()?; + Ok(Json(AuthBody::new(token))) + } + Err(_) => Err(AppError::NotFound), + } +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs new file mode 100644 index 0000000..f9bae3d --- /dev/null +++ b/src/routes/mod.rs @@ -0,0 +1,2 @@ +pub mod auth; +pub mod user; diff --git a/src/routes/user.rs b/src/routes/user.rs new file mode 100644 index 0000000..d44df66 --- /dev/null +++ b/src/routes/user.rs @@ -0,0 +1,39 @@ +use crate::errors::AppError; +use crate::models::{ + auth::Claims, + user::{User, UserCreate, UserList}, +}; +use axum::{extract::Path, routing::get, Json, Router}; + +/// Create routes for `/v1/users/` namespace +pub fn create_route() -> Router { + Router::new() + .route("/", get(list_users).post(create_user)) + .route("/:id", get(get_user)) +} + +/// List users. Checks Authorization token +async fn list_users(_: Claims) -> Result<Json<Vec<UserList>>, AppError> { + let users = User::list().await?; + + Ok(Json(users)) +} + +/// Create an user. Checks Authorization token +async fn create_user( + Json(payload): Json<UserCreate>, + _: Claims, +) -> Result<Json<UserList>, AppError> { + let user = User::new(payload.email, payload.password); + let user_new = User::create(user).await?; + + Ok(Json(user_new)) +} + +/// Get an user with id = `user_id`. Checks Authorization token +async fn get_user(Path(user_id): Path<i32>, _: Claims) -> Result<Json<UserList>, AppError> { + match User::find_by_id(user_id).await { + Ok(user) => Ok(Json(user)), + Err(_) => Err(AppError::NotFound), + } +} |
