summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorSanto Cariotti <santo@dcariotti.me>2025-04-17 22:08:43 +0200
committerSanto Cariotti <santo@dcariotti.me>2025-04-17 22:08:43 +0200
commit8255fbdd7d9d595e71545b7c6909114024527a34 (patch)
tree94773150af8b9d0a2b4e5b548923441cbc107b34 /internal
parent9cd48c660231592f3f8d9a035d45b568d987616e (diff)
Logger with also stdout and move logic to network.Me() instead of network.Peer()
Diffstat (limited to 'internal')
-rw-r--r--internal/api/auth/auth.go24
-rw-r--r--internal/api/auth/auth_test.go2
-rw-r--r--internal/api/handlers/handlers.go104
-rw-r--r--internal/api/middleware/middleware.go11
-rw-r--r--internal/logger/logger.go32
-rw-r--r--internal/network/ip.go2
6 files changed, 91 insertions, 84 deletions
diff --git a/internal/api/auth/auth.go b/internal/api/auth/auth.go
index b382beb..966a09c 100644
--- a/internal/api/auth/auth.go
+++ b/internal/api/auth/auth.go
@@ -7,17 +7,26 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
+ "gorm.io/gorm"
)
+// Key used for JWT encryption/decryption
var jwtKey = []byte(os.Getenv("JWT_SECRET"))
+// Kind of JWT token
+var TokenType = "Bearer"
+
+// Extends JWT Claims with the UserID field
type Claims struct {
UserID int `json:"user_id"`
jwt.RegisteredClaims
}
+// Generate a JWT token from an userID.
func GenerateJWT(userID int) (string, error) {
- expirationTime := time.Now().Add(5 * time.Hour)
+ // Set expiration date for the token to 90 days
+ expirationTime := time.Now().Add(90 * 24 * time.Hour)
+
claims := &Claims{
UserID: userID,
RegisteredClaims: jwt.RegisteredClaims{
@@ -30,17 +39,21 @@ func GenerateJWT(userID int) (string, error) {
if err != nil {
return "", err
}
- return tokenString, nil
+ return TokenType + " " + tokenString, nil
}
+// Validate a JWT token for a kind of time
func ValidateJWT(tokenString string) (*Claims, error) {
claims := &Claims{}
- // A token has a form `Bearer ...`
tokenParts := strings.Split(tokenString, " ")
if len(tokenParts) != 2 {
return nil, errors.New("not valid JWT")
}
+ if tokenParts[0] != TokenType {
+ return nil, errors.New("not valid JWT type")
+ }
+
token, err := jwt.ParseWithClaims(tokenParts[1], claims, func(token *jwt.Token) (interface{}, error) {
return jwtKey, nil
})
@@ -55,3 +68,8 @@ func ValidateJWT(tokenString string) (*Claims, error) {
return claims, nil
}
+
+// Common omit password field for users
+func OmitPassword(db *gorm.DB) *gorm.DB {
+ return db.Omit("Password")
+}
diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go
index 66bcc27..50b6c9b 100644
--- a/internal/api/auth/auth_test.go
+++ b/internal/api/auth/auth_test.go
@@ -19,8 +19,6 @@ func TestGenerateAndValidateJWT(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, tokenString)
- tokenString = "Bearer " + tokenString
-
claims, err := ValidateJWT(tokenString)
assert.NoError(t, err)
assert.NotNil(t, claims)
diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go
index 41779c7..6d1b4e3 100644
--- a/internal/api/handlers/handlers.go
+++ b/internal/api/handlers/handlers.go
@@ -21,6 +21,7 @@ type NewGameRequest struct {
func RegisterUser(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("POST /auth/register")
+
var user database.User
err := json.NewDecoder(r.Body).Decode(&user)
if err != nil {
@@ -35,9 +36,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) {
var storedUser database.User
db, _ := database.GetDb()
- result := db.Where("username = ?", user.Username).First(&storedUser)
-
- if result.Error == nil {
+ if result := db.Where("username = ?", user.Username).First(&storedUser); result.Error == nil {
JsonError(&w, "user with this username already exists")
return
}
@@ -49,8 +48,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) {
}
user.Password = string(hashedPassword)
- result = db.Create(&user)
- if result.Error != nil {
+ if result := db.Create(&user); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -67,6 +65,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) {
func LoginUser(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("POST /auth/login")
+
var inputUser database.User
err := json.NewDecoder(r.Body).Decode(&inputUser)
if err != nil {
@@ -77,8 +76,7 @@ func LoginUser(w http.ResponseWriter, r *http.Request) {
var storedUser database.User
db, _ := database.GetDb()
- result := db.Where("username = ?", inputUser.Username).First(&storedUser)
- if result.Error != nil {
+ if result := db.Where("username = ?", inputUser.Username).First(&storedUser); result.Error != nil {
JsonError(&w, "invalid credentials")
return
}
@@ -100,10 +98,10 @@ func LoginUser(w http.ResponseWriter, r *http.Request) {
func NewPlay(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("POST /play")
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
@@ -116,11 +114,6 @@ func NewPlay(w http.ResponseWriter, r *http.Request) {
return
}
- if err != nil {
- JsonError(&w, err.Error())
- return
- }
-
db, _ := database.GetDb()
name := network.NewSession()
@@ -133,8 +126,7 @@ func NewPlay(w http.ResponseWriter, r *http.Request) {
Outcome: "*",
}
- result := db.Create(&play)
- if result.Error != nil {
+ if result := db.Create(&play); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -145,10 +137,10 @@ func NewPlay(w http.ResponseWriter, r *http.Request) {
func EnterGame(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("POST /enter-game")
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
@@ -162,17 +154,11 @@ func EnterGame(w http.ResponseWriter, r *http.Request) {
return
}
- if err != nil {
- JsonError(&w, err.Error())
- return
- }
-
db, _ := database.GetDb()
var game database.Game
- result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&game)
- if result.Error != nil {
+ if result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&game); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -186,13 +172,9 @@ func EnterGame(w http.ResponseWriter, r *http.Request) {
return
}
- result = db.Where("id = ?", game.ID).
- Preload("Player1", func(db *gorm.DB) *gorm.DB {
- return db.Omit("Password")
- }).
- Preload("Player2", func(db *gorm.DB) *gorm.DB {
- return db.Omit("Password")
- }).
+ result := db.Where("id = ?", game.ID).
+ Preload("Player1", auth.OmitPassword).
+ Preload("Player2", auth.OmitPassword).
First(&game)
if result.Error != nil {
@@ -207,17 +189,16 @@ func AllPlay(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("GET /play")
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
-
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
db, _ := database.GetDb()
var games []database.Game
- result := db.Where("player1_id = ? OR player2_id = ?", claims.UserID, claims.UserID).
+ if result := db.Where("player1_id = ? OR player2_id = ?", claims.UserID, claims.UserID).
Preload("Player1", func(db *gorm.DB) *gorm.DB {
return db.Omit("Password")
}).
@@ -225,9 +206,7 @@ func AllPlay(w http.ResponseWriter, r *http.Request) {
return db.Omit("Password")
}).
Order("updated_at DESC").
- Find(&games)
-
- if result.Error != nil {
+ Find(&games); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -241,26 +220,23 @@ func GetGameId(w http.ResponseWriter, r *http.Request) {
id := vars["id"]
log.Info(fmt.Sprintf("GET /play/%s", id))
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
-
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
db, _ := database.GetDb()
var game database.Game
- result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID).
+ if result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID).
Preload("Player1", func(db *gorm.DB) *gorm.DB {
return db.Omit("Password")
}).
Preload("Player2", func(db *gorm.DB) *gorm.DB {
return db.Omit("Password")
}).
- First(&game)
-
- if result.Error != nil {
+ First(&game); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -274,10 +250,9 @@ func EndGame(w http.ResponseWriter, r *http.Request) {
id := vars["id"]
log.Info(fmt.Sprintf("POST /play/%s/end", id))
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
-
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
@@ -290,18 +265,15 @@ func EndGame(w http.ResponseWriter, r *http.Request) {
return
}
- if err != nil {
- JsonError(&w, err.Error())
- return
- }
-
db, _ := database.GetDb()
var game database.Game
// FIXME: this is not secure
- result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID).First(&game)
- if result.Error != nil {
+ if result := db.Where(
+ "id = ? AND (player1_id = ? OR player2_id = ?)",
+ id, claims.UserID, claims.UserID,
+ ).First(&game); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -313,13 +285,9 @@ func EndGame(w http.ResponseWriter, r *http.Request) {
return
}
- result = db.Where("id = ?", game.ID).
- Preload("Player1", func(db *gorm.DB) *gorm.DB {
- return db.Omit("Password")
- }).
- Preload("Player2", func(db *gorm.DB) *gorm.DB {
- return db.Omit("Password")
- }).
+ result := db.Where("id = ?", game.ID).
+ Preload("Player1", auth.OmitPassword).
+ Preload("Player2", auth.OmitPassword).
First(&game)
if result.Error != nil {
diff --git a/internal/api/middleware/middleware.go b/internal/api/middleware/middleware.go
index 0334e78..d7c5a30 100644
--- a/internal/api/middleware/middleware.go
+++ b/internal/api/middleware/middleware.go
@@ -1,12 +1,16 @@
package middleware
import (
+ "context"
"encoding/json"
"net/http"
"github.com/boozec/rahanna/internal/api/auth"
)
+// AuthMiddleware ensures that the requester has passed the Authorization
+// header with a valid JWY token.
+// It passes the claims item via context
func AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString := r.Header.Get("Authorization")
@@ -22,7 +26,7 @@ func AuthMiddleware(next http.Handler) http.Handler {
return
}
- _, err := auth.ValidateJWT(tokenString)
+ claims, err := auth.ValidateJWT(tokenString)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
@@ -31,6 +35,9 @@ func AuthMiddleware(next http.Handler) http.Handler {
w.Write([]byte(payload))
return
}
- next.ServeHTTP(w, r)
+
+ ctx := context.WithValue(r.Context(), "claims", claims)
+
+ next.ServeHTTP(w, r.WithContext(ctx))
})
}
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
index a5d0264..b7ecb86 100644
--- a/internal/logger/logger.go
+++ b/internal/logger/logger.go
@@ -2,6 +2,7 @@ package logger
import (
"errors"
+ "os"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -10,31 +11,44 @@ import (
var logger *zap.Logger = nil
-func InitLogger(logFile string) *zap.Logger {
+// Set up a new Zap logger. If `onlyFile` is true, set up the logger to work
+// only on file, else prints on stdout
+func InitLogger(logFile string, onlyFile bool) *zap.Logger {
cfg := zap.NewProductionConfig()
- cfg.OutputPaths = []string{logFile}
- cfg.ErrorOutputPaths = []string{logFile}
+ cfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
+
+ var core zapcore.Core
- // Configure lumberjack for log rotation
lumberjackLogger := &lumberjack.Logger{
Filename: logFile,
- MaxSize: 100, // megabytes
+ MaxSize: 100,
MaxBackups: 5,
- MaxAge: 30, // days
+ MaxAge: 30,
Compress: true,
}
- core := zapcore.NewCore(
+ fileCore := zapcore.NewCore(
zapcore.NewJSONEncoder(cfg.EncoderConfig),
- zapcore.AddSync(lumberjackLogger), // Log only to the file via lumberjack
+ zapcore.AddSync(lumberjackLogger),
cfg.Level,
)
- logger = zap.New(core)
+ if onlyFile {
+ core = fileCore
+ } else {
+ consoleCore := zapcore.NewCore(
+ zapcore.NewConsoleEncoder(cfg.EncoderConfig),
+ zapcore.Lock(os.Stdout),
+ cfg.Level,
+ )
+ core = zapcore.NewTee(fileCore, consoleCore)
+ }
+ logger = zap.New(core)
return logger
}
+// Return the global Zap logger after calling `InitLogger` method
func GetLogger() (*zap.Logger, error) {
if logger == nil {
return nil, errors.New("You must call `InitLogger()` first.")
diff --git a/internal/network/ip.go b/internal/network/ip.go
index dcd15db..ec1e984 100644
--- a/internal/network/ip.go
+++ b/internal/network/ip.go
@@ -8,6 +8,7 @@ import (
"github.com/boozec/rahanna/internal/logger"
)
+// Connect a DNS to get the address
func GetOutboundIP() net.IP {
log, _ := logger.GetLogger()
conn, err := net.Dial("udp", "8.8.8.8:80")
@@ -21,6 +22,7 @@ func GetOutboundIP() net.IP {
return localAddr.IP
}
+// Returns a random available port on the node
func GetRandomAvailablePort() (int, error) {
for i := 0; i < 100; i += 1 {
port := rand.Intn(65535-1024) + 1024