summaryrefslogtreecommitdiff
path: root/api
diff options
context:
space:
mode:
Diffstat (limited to 'api')
-rw-r--r--api/auth/auth.go48
-rw-r--r--api/auth/auth_test.go74
-rw-r--r--api/database/database.go32
-rw-r--r--api/database/models.go11
-rw-r--r--api/handlers/handlers.go68
-rw-r--r--api/middleware/middleware.go36
6 files changed, 269 insertions, 0 deletions
diff --git a/api/auth/auth.go b/api/auth/auth.go
new file mode 100644
index 0000000..23b4f53
--- /dev/null
+++ b/api/auth/auth.go
@@ -0,0 +1,48 @@
+package auth
+
+import (
+ "github.com/golang-jwt/jwt/v5"
+ "os"
+ "time"
+)
+
+var jwtKey = []byte(os.Getenv("JWT_SECRET"))
+
+type Claims struct {
+ UserID int `json:"user_id"`
+ jwt.RegisteredClaims
+}
+
+func GenerateJWT(userID int) (string, error) {
+ expirationTime := time.Now().Add(5 * time.Hour)
+ claims := &Claims{
+ UserID: userID,
+ RegisteredClaims: jwt.RegisteredClaims{
+ ExpiresAt: jwt.NewNumericDate(expirationTime),
+ },
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ tokenString, err := token.SignedString(jwtKey)
+ if err != nil {
+ return "", err
+ }
+ return tokenString, nil
+}
+
+func ValidateJWT(tokenString string) (*Claims, error) {
+ claims := &Claims{}
+ token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
+ return jwtKey, nil
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ if !token.Valid {
+ return nil, err
+ }
+
+ return claims, nil
+}
diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go
new file mode 100644
index 0000000..50b6c9b
--- /dev/null
+++ b/api/auth/auth_test.go
@@ -0,0 +1,74 @@
+package auth
+
+import (
+ "os"
+ "testing"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestGenerateAndValidateJWT(t *testing.T) {
+ // Set up the JWT secret for the test.
+ os.Setenv("JWT_SECRET", "testsecret")
+ jwtKey = []byte(os.Getenv("JWT_SECRET"))
+
+ userID := 123
+ tokenString, err := GenerateJWT(userID)
+ assert.NoError(t, err)
+ assert.NotEmpty(t, tokenString)
+
+ claims, err := ValidateJWT(tokenString)
+ assert.NoError(t, err)
+ assert.NotNil(t, claims)
+ assert.Equal(t, userID, claims.UserID)
+ assert.True(t, claims.ExpiresAt.After(time.Now()))
+}
+
+func TestValidateJWT_InvalidToken(t *testing.T) {
+ os.Setenv("JWT_SECRET", "testsecret")
+ jwtKey = []byte(os.Getenv("JWT_SECRET"))
+
+ _, err := ValidateJWT("invalid_token")
+ assert.Error(t, err)
+}
+
+func TestValidateJWT_ExpiredToken(t *testing.T) {
+ os.Setenv("JWT_SECRET", "testsecret")
+ jwtKey = []byte(os.Getenv("JWT_SECRET"))
+
+ // Create a token that has already expired.
+ claims := &Claims{
+ UserID: 123,
+ RegisteredClaims: jwt.RegisteredClaims{
+ ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)),
+ },
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ tokenString, err := token.SignedString(jwtKey)
+ assert.NoError(t, err)
+
+ _, err = ValidateJWT(tokenString)
+ assert.Error(t, err)
+}
+
+func TestValidateJWT_WrongSecret(t *testing.T) {
+ os.Setenv("JWT_SECRET", "testsecret")
+ jwtKey = []byte(os.Getenv("JWT_SECRET"))
+
+ userID := 123
+ tokenString, err := GenerateJWT(userID)
+ assert.NoError(t, err)
+
+ // Set a different secret for validation.
+ wrongKey := []byte("wrongsecret")
+
+ claims := &Claims{}
+ _, err = jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
+ return wrongKey, nil
+ })
+
+ assert.Error(t, err)
+}
diff --git a/api/database/database.go b/api/database/database.go
new file mode 100644
index 0000000..e5ecca8
--- /dev/null
+++ b/api/database/database.go
@@ -0,0 +1,32 @@
+package database
+
+import (
+ "gorm.io/driver/postgres"
+ "gorm.io/gorm"
+
+ "errors"
+)
+
+// Global variable but private
+var db *gorm.DB = nil
+
+// Init the database from a DSN string which must be a valid PostgreSQL dsn.
+// Also, auto migrate all the models.
+func InitDb(dsn string) (*gorm.DB, error) {
+ var err error
+ db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
+
+ if err == nil {
+ db.AutoMigrate(&User{})
+ }
+
+ return db, err
+}
+
+// Return the instance or error if the config is not laoded yet
+func GetDb() (*gorm.DB, error) {
+ if db == nil {
+ return nil, errors.New("You must call `InitDb()` first.")
+ }
+ return db, nil
+}
diff --git a/api/database/models.go b/api/database/models.go
new file mode 100644
index 0000000..e309a36
--- /dev/null
+++ b/api/database/models.go
@@ -0,0 +1,11 @@
+package database
+
+import "time"
+
+type User struct {
+ ID int `json:"id"`
+ Username string `json:"username"`
+ Password string `json:"password"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
diff --git a/api/handlers/handlers.go b/api/handlers/handlers.go
new file mode 100644
index 0000000..7d5fd10
--- /dev/null
+++ b/api/handlers/handlers.go
@@ -0,0 +1,68 @@
+package handlers
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/boozec/rahanna/api/auth"
+ "github.com/boozec/rahanna/api/database"
+ "golang.org/x/crypto/bcrypt"
+)
+
+func RegisterUser(w http.ResponseWriter, r *http.Request) {
+ var user database.User
+ err := json.NewDecoder(r.Body).Decode(&user)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ user.Password = string(hashedPassword)
+
+ db, _ := database.GetDb()
+
+ result := db.Create(&user)
+ if result.Error != nil {
+ http.Error(w, result.Error.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ w.WriteHeader(http.StatusCreated)
+}
+
+func LoginUser(w http.ResponseWriter, r *http.Request) {
+ var inputUser database.User
+ err := json.NewDecoder(r.Body).Decode(&inputUser)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ var storedUser database.User
+
+ db, _ := database.GetDb()
+ result := db.Where("username = ?", inputUser.Username).First(&storedUser)
+ if result.Error != nil {
+ http.Error(w, "Invalid credentials", http.StatusUnauthorized)
+ return
+ }
+
+ err = bcrypt.CompareHashAndPassword([]byte(storedUser.Password), []byte(inputUser.Password))
+ if err != nil {
+ http.Error(w, "Invalid credentials", http.StatusUnauthorized)
+ return
+ }
+
+ token, err := auth.GenerateJWT(storedUser.ID)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ json.NewEncoder(w).Encode(map[string]string{"token": token})
+}
diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go
new file mode 100644
index 0000000..29ed8b6
--- /dev/null
+++ b/api/middleware/middleware.go
@@ -0,0 +1,36 @@
+package middleware
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "github.com/boozec/rahanna/api/auth"
+)
+
+func AuthMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ tokenString := r.Header.Get("Authorization")
+
+ payloadMap := map[string]string{"error": "unauthorized"}
+ payload, _ := json.Marshal(payloadMap)
+
+ if tokenString == "" {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusUnauthorized)
+
+ w.Write([]byte(payload))
+ return
+ }
+
+ _, err := auth.ValidateJWT(tokenString)
+ if err != nil {
+ w.WriteHeader(http.StatusUnauthorized)
+
+ payload, _ := json.Marshal(payloadMap)
+
+ w.Write([]byte(payload))
+ return
+ }
+ next.ServeHTTP(w, r)
+ })
+}