diff options
Diffstat (limited to 'api')
-rw-r--r-- | api/auth/auth.go | 57 | ||||
-rw-r--r-- | api/auth/auth_test.go | 74 | ||||
-rw-r--r-- | api/database/database.go | 32 | ||||
-rw-r--r-- | api/database/models.go | 24 | ||||
-rw-r--r-- | api/handlers/handlers.go | 200 | ||||
-rw-r--r-- | api/middleware/middleware.go | 36 |
6 files changed, 0 insertions, 423 deletions
diff --git a/api/auth/auth.go b/api/auth/auth.go deleted file mode 100644 index b382beb..0000000 --- a/api/auth/auth.go +++ /dev/null @@ -1,57 +0,0 @@ -package auth - -import ( - "errors" - "os" - "strings" - "time" - - "github.com/golang-jwt/jwt/v5" -) - -var jwtKey = []byte(os.Getenv("JWT_SECRET")) - -type Claims struct { - UserID int `json:"user_id"` - jwt.RegisteredClaims -} - -func GenerateJWT(userID int) (string, error) { - expirationTime := time.Now().Add(5 * time.Hour) - claims := &Claims{ - UserID: userID, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(expirationTime), - }, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString(jwtKey) - if err != nil { - return "", err - } - return tokenString, nil -} - -func ValidateJWT(tokenString string) (*Claims, error) { - claims := &Claims{} - // A token has a form `Bearer ...` - tokenParts := strings.Split(tokenString, " ") - if len(tokenParts) != 2 { - return nil, errors.New("not valid JWT") - } - - token, err := jwt.ParseWithClaims(tokenParts[1], claims, func(token *jwt.Token) (interface{}, error) { - return jwtKey, nil - }) - - if err != nil { - return nil, err - } - - if !token.Valid { - return nil, err - } - - return claims, nil -} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go deleted file mode 100644 index 50b6c9b..0000000 --- a/api/auth/auth_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package auth - -import ( - "os" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/stretchr/testify/assert" -) - -func TestGenerateAndValidateJWT(t *testing.T) { - // Set up the JWT secret for the test. - os.Setenv("JWT_SECRET", "testsecret") - jwtKey = []byte(os.Getenv("JWT_SECRET")) - - userID := 123 - tokenString, err := GenerateJWT(userID) - assert.NoError(t, err) - assert.NotEmpty(t, tokenString) - - claims, err := ValidateJWT(tokenString) - assert.NoError(t, err) - assert.NotNil(t, claims) - assert.Equal(t, userID, claims.UserID) - assert.True(t, claims.ExpiresAt.After(time.Now())) -} - -func TestValidateJWT_InvalidToken(t *testing.T) { - os.Setenv("JWT_SECRET", "testsecret") - jwtKey = []byte(os.Getenv("JWT_SECRET")) - - _, err := ValidateJWT("invalid_token") - assert.Error(t, err) -} - -func TestValidateJWT_ExpiredToken(t *testing.T) { - os.Setenv("JWT_SECRET", "testsecret") - jwtKey = []byte(os.Getenv("JWT_SECRET")) - - // Create a token that has already expired. - claims := &Claims{ - UserID: 123, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Hour)), - }, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString(jwtKey) - assert.NoError(t, err) - - _, err = ValidateJWT(tokenString) - assert.Error(t, err) -} - -func TestValidateJWT_WrongSecret(t *testing.T) { - os.Setenv("JWT_SECRET", "testsecret") - jwtKey = []byte(os.Getenv("JWT_SECRET")) - - userID := 123 - tokenString, err := GenerateJWT(userID) - assert.NoError(t, err) - - // Set a different secret for validation. - wrongKey := []byte("wrongsecret") - - claims := &Claims{} - _, err = jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { - return wrongKey, nil - }) - - assert.Error(t, err) -} diff --git a/api/database/database.go b/api/database/database.go deleted file mode 100644 index 4470c58..0000000 --- a/api/database/database.go +++ /dev/null @@ -1,32 +0,0 @@ -package database - -import ( - "gorm.io/driver/postgres" - "gorm.io/gorm" - - "errors" -) - -// Global variable but private -var db *gorm.DB = nil - -// Init the database from a DSN string which must be a valid PostgreSQL dsn. -// Also, auto migrate all the models. -func InitDb(dsn string) (*gorm.DB, error) { - var err error - db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) - - if err == nil { - db.AutoMigrate(&User{}, &Game{}) - } - - return db, err -} - -// Return the instance or error if the config is not laoded yet -func GetDb() (*gorm.DB, error) { - if db == nil { - return nil, errors.New("You must call `InitDb()` first.") - } - return db, nil -} diff --git a/api/database/models.go b/api/database/models.go deleted file mode 100644 index a6e76c5..0000000 --- a/api/database/models.go +++ /dev/null @@ -1,24 +0,0 @@ -package database - -import "time" - -type User struct { - ID int `json:"id"` - Username string `json:"username"` - Password string `json:"password"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -type Game struct { - ID int `json:"id"` - Player1ID int `json:"-"` - Player1 User `gorm:"foreignKey:Player1ID" json:"player1"` - Player2ID *int `json:"-"` - Player2 *User `gorm:"foreignKey:Player2ID;null" json:"player2"` - Name string `json:"name"` - IP1 string `json:"ip1"` - IP2 string `json:"ip2"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} diff --git a/api/handlers/handlers.go b/api/handlers/handlers.go deleted file mode 100644 index 601b770..0000000 --- a/api/handlers/handlers.go +++ /dev/null @@ -1,200 +0,0 @@ -package handlers - -import ( - "encoding/json" - "log/slog" - "net/http" - "time" - - "github.com/boozec/rahanna/api/auth" - "github.com/boozec/rahanna/api/database" - "github.com/boozec/rahanna/network" - utils "github.com/boozec/rahanna/pkg" - "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" -) - -type NewGameRequest struct { - IP string `json:"ip"` -} - -func RegisterUser(w http.ResponseWriter, r *http.Request) { - slog.Info("POST /auth/register") - var user database.User - err := json.NewDecoder(r.Body).Decode(&user) - if err != nil { - utils.JsonError(&w, err.Error()) - return - } - - if len(user.Password) < 4 { - utils.JsonError(&w, "password too short") - return - } - - var storedUser database.User - db, _ := database.GetDb() - result := db.Where("username = ?", user.Username).First(&storedUser) - - if result.Error == nil { - utils.JsonError(&w, "user with this username already exists") - return - } - - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) - if err != nil { - utils.JsonError(&w, err.Error()) - return - } - user.Password = string(hashedPassword) - - result = db.Create(&user) - if result.Error != nil { - utils.JsonError(&w, result.Error.Error()) - return - } - - token, err := auth.GenerateJWT(user.ID) - if err != nil { - utils.JsonError(&w, err.Error()) - return - } - - json.NewEncoder(w).Encode(map[string]string{"token": token}) -} - -func LoginUser(w http.ResponseWriter, r *http.Request) { - slog.Info("POST /auth/login") - var inputUser database.User - err := json.NewDecoder(r.Body).Decode(&inputUser) - if err != nil { - utils.JsonError(&w, err.Error()) - return - } - - var storedUser database.User - - db, _ := database.GetDb() - result := db.Where("username = ?", inputUser.Username).First(&storedUser) - if result.Error != nil { - utils.JsonError(&w, "invalid credentials") - return - } - - err = bcrypt.CompareHashAndPassword([]byte(storedUser.Password), []byte(inputUser.Password)) - if err != nil { - utils.JsonError(&w, "invalid credentials") - return - } - - token, err := auth.GenerateJWT(storedUser.ID) - if err != nil { - utils.JsonError(&w, err.Error()) - return - } - - json.NewEncoder(w).Encode(map[string]string{"token": token}) -} - -func NewPlay(w http.ResponseWriter, r *http.Request) { - slog.Info("POST /play") - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - - if err != nil { - utils.JsonError(&w, err.Error()) - return - } - - var payload struct { - IP string `json:"ip"` - } - - if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { - utils.JsonError(&w, err.Error()) - return - } - - if err != nil { - utils.JsonError(&w, err.Error()) - return - } - - db, _ := database.GetDb() - - name := network.NewSession() - play := database.Game{ - Player1ID: claims.UserID, - Player2ID: nil, - Name: name, - IP1: payload.IP, - IP2: "", - } - - result := db.Create(&play) - if result.Error != nil { - utils.JsonError(&w, result.Error.Error()) - return - } - - json.NewEncoder(w).Encode(map[string]string{"name": name}) -} - -func EnterGame(w http.ResponseWriter, r *http.Request) { - slog.Info("POST /enter-game") - claims, err := auth.ValidateJWT(r.Header.Get("Authorization")) - - if err != nil { - utils.JsonError(&w, err.Error()) - return - } - - var payload struct { - Name string `json:"name"` - IP string `json:"ip"` - } - - if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { - utils.JsonError(&w, err.Error()) - return - } - - if err != nil { - utils.JsonError(&w, err.Error()) - return - } - - db, _ := database.GetDb() - - var play database.Game - - result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&play) - if result.Error != nil { - utils.JsonError(&w, result.Error.Error()) - return - } - - play.Player2ID = &claims.UserID - play.IP2 = payload.IP - play.UpdatedAt = time.Now() - - if err := db.Save(&play).Error; err != nil { - utils.JsonError(&w, err.Error()) - return - } - - result = db.Where("id = ?", play.ID). - Preload("Player1", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). - Preload("Player2", func(db *gorm.DB) *gorm.DB { - return db.Omit("Password") - }). - First(&play) - - if result.Error != nil { - utils.JsonError(&w, result.Error.Error()) - return - } - - json.NewEncoder(w).Encode(play) -} diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go deleted file mode 100644 index 29ed8b6..0000000 --- a/api/middleware/middleware.go +++ /dev/null @@ -1,36 +0,0 @@ -package middleware - -import ( - "encoding/json" - "net/http" - - "github.com/boozec/rahanna/api/auth" -) - -func AuthMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tokenString := r.Header.Get("Authorization") - - payloadMap := map[string]string{"error": "unauthorized"} - payload, _ := json.Marshal(payloadMap) - - if tokenString == "" { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - - w.Write([]byte(payload)) - return - } - - _, err := auth.ValidateJWT(tokenString) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - - payload, _ := json.Marshal(payloadMap) - - w.Write([]byte(payload)) - return - } - next.ServeHTTP(w, r) - }) -} |