summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/api/main.go3
-rw-r--r--cmd/ui/main.go2
-rw-r--r--internal/api/auth/auth.go24
-rw-r--r--internal/api/auth/auth_test.go2
-rw-r--r--internal/api/handlers/handlers.go104
-rw-r--r--internal/api/middleware/middleware.go11
-rw-r--r--internal/logger/logger.go32
-rw-r--r--internal/network/ip.go2
-rw-r--r--pkg/ui/multiplayer/multiplayer.go32
-rw-r--r--pkg/ui/views/api.go4
-rw-r--r--pkg/ui/views/game.go23
-rw-r--r--pkg/ui/views/game_api.go11
-rw-r--r--pkg/ui/views/game_keymap.go10
-rw-r--r--pkg/ui/views/game_moves.go4
-rw-r--r--pkg/ui/views/game_util.go2
-rw-r--r--pkg/ui/views/play.go2
-rw-r--r--pkg/ui/views/play_api.go6
17 files changed, 147 insertions, 127 deletions
diff --git a/cmd/api/main.go b/cmd/api/main.go
index 4bd538a..6d6b354 100644
--- a/cmd/api/main.go
+++ b/cmd/api/main.go
@@ -14,7 +14,8 @@ import (
func main() {
database.InitDb(os.Getenv("DATABASE_URL"))
- log := logger.InitLogger("rahanna.log")
+ log := logger.InitLogger("rahanna.log", false)
+
r := mux.NewRouter()
r.HandleFunc("/auth/register", handlers.RegisterUser).Methods(http.MethodPost)
diff --git a/cmd/ui/main.go b/cmd/ui/main.go
index 97f894a..930adf2 100644
--- a/cmd/ui/main.go
+++ b/cmd/ui/main.go
@@ -10,7 +10,7 @@ import (
func main() {
views.ClearScreen()
- _ = logger.InitLogger("rahanna-ui.log")
+ _ = logger.InitLogger("rahanna-ui.log", true)
p := tea.NewProgram(views.NewRahannaModel(), tea.WithAltScreen())
diff --git a/internal/api/auth/auth.go b/internal/api/auth/auth.go
index b382beb..966a09c 100644
--- a/internal/api/auth/auth.go
+++ b/internal/api/auth/auth.go
@@ -7,17 +7,26 @@ import (
"time"
"github.com/golang-jwt/jwt/v5"
+ "gorm.io/gorm"
)
+// Key used for JWT encryption/decryption
var jwtKey = []byte(os.Getenv("JWT_SECRET"))
+// Kind of JWT token
+var TokenType = "Bearer"
+
+// Extends JWT Claims with the UserID field
type Claims struct {
UserID int `json:"user_id"`
jwt.RegisteredClaims
}
+// Generate a JWT token from an userID.
func GenerateJWT(userID int) (string, error) {
- expirationTime := time.Now().Add(5 * time.Hour)
+ // Set expiration date for the token to 90 days
+ expirationTime := time.Now().Add(90 * 24 * time.Hour)
+
claims := &Claims{
UserID: userID,
RegisteredClaims: jwt.RegisteredClaims{
@@ -30,17 +39,21 @@ func GenerateJWT(userID int) (string, error) {
if err != nil {
return "", err
}
- return tokenString, nil
+ return TokenType + " " + tokenString, nil
}
+// Validate a JWT token for a kind of time
func ValidateJWT(tokenString string) (*Claims, error) {
claims := &Claims{}
- // A token has a form `Bearer ...`
tokenParts := strings.Split(tokenString, " ")
if len(tokenParts) != 2 {
return nil, errors.New("not valid JWT")
}
+ if tokenParts[0] != TokenType {
+ return nil, errors.New("not valid JWT type")
+ }
+
token, err := jwt.ParseWithClaims(tokenParts[1], claims, func(token *jwt.Token) (interface{}, error) {
return jwtKey, nil
})
@@ -55,3 +68,8 @@ func ValidateJWT(tokenString string) (*Claims, error) {
return claims, nil
}
+
+// Common omit password field for users
+func OmitPassword(db *gorm.DB) *gorm.DB {
+ return db.Omit("Password")
+}
diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go
index 66bcc27..50b6c9b 100644
--- a/internal/api/auth/auth_test.go
+++ b/internal/api/auth/auth_test.go
@@ -19,8 +19,6 @@ func TestGenerateAndValidateJWT(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, tokenString)
- tokenString = "Bearer " + tokenString
-
claims, err := ValidateJWT(tokenString)
assert.NoError(t, err)
assert.NotNil(t, claims)
diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go
index 41779c7..6d1b4e3 100644
--- a/internal/api/handlers/handlers.go
+++ b/internal/api/handlers/handlers.go
@@ -21,6 +21,7 @@ type NewGameRequest struct {
func RegisterUser(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("POST /auth/register")
+
var user database.User
err := json.NewDecoder(r.Body).Decode(&user)
if err != nil {
@@ -35,9 +36,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) {
var storedUser database.User
db, _ := database.GetDb()
- result := db.Where("username = ?", user.Username).First(&storedUser)
-
- if result.Error == nil {
+ if result := db.Where("username = ?", user.Username).First(&storedUser); result.Error == nil {
JsonError(&w, "user with this username already exists")
return
}
@@ -49,8 +48,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) {
}
user.Password = string(hashedPassword)
- result = db.Create(&user)
- if result.Error != nil {
+ if result := db.Create(&user); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -67,6 +65,7 @@ func RegisterUser(w http.ResponseWriter, r *http.Request) {
func LoginUser(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("POST /auth/login")
+
var inputUser database.User
err := json.NewDecoder(r.Body).Decode(&inputUser)
if err != nil {
@@ -77,8 +76,7 @@ func LoginUser(w http.ResponseWriter, r *http.Request) {
var storedUser database.User
db, _ := database.GetDb()
- result := db.Where("username = ?", inputUser.Username).First(&storedUser)
- if result.Error != nil {
+ if result := db.Where("username = ?", inputUser.Username).First(&storedUser); result.Error != nil {
JsonError(&w, "invalid credentials")
return
}
@@ -100,10 +98,10 @@ func LoginUser(w http.ResponseWriter, r *http.Request) {
func NewPlay(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("POST /play")
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
@@ -116,11 +114,6 @@ func NewPlay(w http.ResponseWriter, r *http.Request) {
return
}
- if err != nil {
- JsonError(&w, err.Error())
- return
- }
-
db, _ := database.GetDb()
name := network.NewSession()
@@ -133,8 +126,7 @@ func NewPlay(w http.ResponseWriter, r *http.Request) {
Outcome: "*",
}
- result := db.Create(&play)
- if result.Error != nil {
+ if result := db.Create(&play); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -145,10 +137,10 @@ func NewPlay(w http.ResponseWriter, r *http.Request) {
func EnterGame(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("POST /enter-game")
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
@@ -162,17 +154,11 @@ func EnterGame(w http.ResponseWriter, r *http.Request) {
return
}
- if err != nil {
- JsonError(&w, err.Error())
- return
- }
-
db, _ := database.GetDb()
var game database.Game
- result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&game)
- if result.Error != nil {
+ if result := db.Where("name = ? AND player2_id IS NULL", payload.Name).First(&game); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -186,13 +172,9 @@ func EnterGame(w http.ResponseWriter, r *http.Request) {
return
}
- result = db.Where("id = ?", game.ID).
- Preload("Player1", func(db *gorm.DB) *gorm.DB {
- return db.Omit("Password")
- }).
- Preload("Player2", func(db *gorm.DB) *gorm.DB {
- return db.Omit("Password")
- }).
+ result := db.Where("id = ?", game.ID).
+ Preload("Player1", auth.OmitPassword).
+ Preload("Player2", auth.OmitPassword).
First(&game)
if result.Error != nil {
@@ -207,17 +189,16 @@ func AllPlay(w http.ResponseWriter, r *http.Request) {
log, _ := logger.GetLogger()
log.Info("GET /play")
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
-
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
db, _ := database.GetDb()
var games []database.Game
- result := db.Where("player1_id = ? OR player2_id = ?", claims.UserID, claims.UserID).
+ if result := db.Where("player1_id = ? OR player2_id = ?", claims.UserID, claims.UserID).
Preload("Player1", func(db *gorm.DB) *gorm.DB {
return db.Omit("Password")
}).
@@ -225,9 +206,7 @@ func AllPlay(w http.ResponseWriter, r *http.Request) {
return db.Omit("Password")
}).
Order("updated_at DESC").
- Find(&games)
-
- if result.Error != nil {
+ Find(&games); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -241,26 +220,23 @@ func GetGameId(w http.ResponseWriter, r *http.Request) {
id := vars["id"]
log.Info(fmt.Sprintf("GET /play/%s", id))
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
-
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
db, _ := database.GetDb()
var game database.Game
- result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID).
+ if result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID).
Preload("Player1", func(db *gorm.DB) *gorm.DB {
return db.Omit("Password")
}).
Preload("Player2", func(db *gorm.DB) *gorm.DB {
return db.Omit("Password")
}).
- First(&game)
-
- if result.Error != nil {
+ First(&game); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -274,10 +250,9 @@ func EndGame(w http.ResponseWriter, r *http.Request) {
id := vars["id"]
log.Info(fmt.Sprintf("POST /play/%s/end", id))
- claims, err := auth.ValidateJWT(r.Header.Get("Authorization"))
-
- if err != nil {
- JsonError(&w, err.Error())
+ claims, ok := r.Context().Value("claims").(*auth.Claims)
+ if !ok {
+ JsonError(&w, "claims not found")
return
}
@@ -290,18 +265,15 @@ func EndGame(w http.ResponseWriter, r *http.Request) {
return
}
- if err != nil {
- JsonError(&w, err.Error())
- return
- }
-
db, _ := database.GetDb()
var game database.Game
// FIXME: this is not secure
- result := db.Where("id = ? AND (player1_id = ? OR player2_id = ?)", id, claims.UserID, claims.UserID).First(&game)
- if result.Error != nil {
+ if result := db.Where(
+ "id = ? AND (player1_id = ? OR player2_id = ?)",
+ id, claims.UserID, claims.UserID,
+ ).First(&game); result.Error != nil {
JsonError(&w, result.Error.Error())
return
}
@@ -313,13 +285,9 @@ func EndGame(w http.ResponseWriter, r *http.Request) {
return
}
- result = db.Where("id = ?", game.ID).
- Preload("Player1", func(db *gorm.DB) *gorm.DB {
- return db.Omit("Password")
- }).
- Preload("Player2", func(db *gorm.DB) *gorm.DB {
- return db.Omit("Password")
- }).
+ result := db.Where("id = ?", game.ID).
+ Preload("Player1", auth.OmitPassword).
+ Preload("Player2", auth.OmitPassword).
First(&game)
if result.Error != nil {
diff --git a/internal/api/middleware/middleware.go b/internal/api/middleware/middleware.go
index 0334e78..d7c5a30 100644
--- a/internal/api/middleware/middleware.go
+++ b/internal/api/middleware/middleware.go
@@ -1,12 +1,16 @@
package middleware
import (
+ "context"
"encoding/json"
"net/http"
"github.com/boozec/rahanna/internal/api/auth"
)
+// AuthMiddleware ensures that the requester has passed the Authorization
+// header with a valid JWY token.
+// It passes the claims item via context
func AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString := r.Header.Get("Authorization")
@@ -22,7 +26,7 @@ func AuthMiddleware(next http.Handler) http.Handler {
return
}
- _, err := auth.ValidateJWT(tokenString)
+ claims, err := auth.ValidateJWT(tokenString)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
@@ -31,6 +35,9 @@ func AuthMiddleware(next http.Handler) http.Handler {
w.Write([]byte(payload))
return
}
- next.ServeHTTP(w, r)
+
+ ctx := context.WithValue(r.Context(), "claims", claims)
+
+ next.ServeHTTP(w, r.WithContext(ctx))
})
}
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
index a5d0264..b7ecb86 100644
--- a/internal/logger/logger.go
+++ b/internal/logger/logger.go
@@ -2,6 +2,7 @@ package logger
import (
"errors"
+ "os"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -10,31 +11,44 @@ import (
var logger *zap.Logger = nil
-func InitLogger(logFile string) *zap.Logger {
+// Set up a new Zap logger. If `onlyFile` is true, set up the logger to work
+// only on file, else prints on stdout
+func InitLogger(logFile string, onlyFile bool) *zap.Logger {
cfg := zap.NewProductionConfig()
- cfg.OutputPaths = []string{logFile}
- cfg.ErrorOutputPaths = []string{logFile}
+ cfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
+
+ var core zapcore.Core
- // Configure lumberjack for log rotation
lumberjackLogger := &lumberjack.Logger{
Filename: logFile,
- MaxSize: 100, // megabytes
+ MaxSize: 100,
MaxBackups: 5,
- MaxAge: 30, // days
+ MaxAge: 30,
Compress: true,
}
- core := zapcore.NewCore(
+ fileCore := zapcore.NewCore(
zapcore.NewJSONEncoder(cfg.EncoderConfig),
- zapcore.AddSync(lumberjackLogger), // Log only to the file via lumberjack
+ zapcore.AddSync(lumberjackLogger),
cfg.Level,
)
- logger = zap.New(core)
+ if onlyFile {
+ core = fileCore
+ } else {
+ consoleCore := zapcore.NewCore(
+ zapcore.NewConsoleEncoder(cfg.EncoderConfig),
+ zapcore.Lock(os.Stdout),
+ cfg.Level,
+ )
+ core = zapcore.NewTee(fileCore, consoleCore)
+ }
+ logger = zap.New(core)
return logger
}
+// Return the global Zap logger after calling `InitLogger` method
func GetLogger() (*zap.Logger, error) {
if logger == nil {
return nil, errors.New("You must call `InitLogger()` first.")
diff --git a/internal/network/ip.go b/internal/network/ip.go
index dcd15db..ec1e984 100644
--- a/internal/network/ip.go
+++ b/internal/network/ip.go
@@ -8,6 +8,7 @@ import (
"github.com/boozec/rahanna/internal/logger"
)
+// Connect a DNS to get the address
func GetOutboundIP() net.IP {
log, _ := logger.GetLogger()
conn, err := net.Dial("udp", "8.8.8.8:80")
@@ -21,6 +22,7 @@ func GetOutboundIP() net.IP {
return localAddr.IP
}
+// Returns a random available port on the node
func GetRandomAvailablePort() (int, error) {
for i := 0; i < 100; i += 1 {
port := rand.Intn(65535-1024) + 1024
diff --git a/pkg/ui/multiplayer/multiplayer.go b/pkg/ui/multiplayer/multiplayer.go
index 1680035..c9fc4b2 100644
--- a/pkg/ui/multiplayer/multiplayer.go
+++ b/pkg/ui/multiplayer/multiplayer.go
@@ -8,10 +8,12 @@ import (
)
type GameNetwork struct {
- Server *network.TCPNetwork
- Peer string
+ server *network.TCPNetwork
+ me network.NetworkID
+ peer network.NetworkID
}
+// Wrapper to a `TCPNetwork`
func NewGameNetwork(localID string, address string, onHandshake network.NetworkHandshakeFunc, logger *zap.Logger) *GameNetwork {
opts := network.TCPNetworkOpts{
ListenAddr: address,
@@ -20,9 +22,29 @@ func NewGameNetwork(localID string, address string, onHandshake network.NetworkH
Logger: logger,
}
server := network.NewTCPNetwork(network.NetworkID(localID), opts)
- peer := ""
return &GameNetwork{
- Server: server,
- Peer: peer,
+ server: server,
+ me: network.NetworkID(localID),
}
}
+
+func (n *GameNetwork) Peer() network.NetworkID {
+ return n.peer
+}
+
+func (n *GameNetwork) Me() network.NetworkID {
+ return n.me
+}
+
+func (n *GameNetwork) Send(payload []byte) error {
+ return n.server.Send(n.peer, payload)
+}
+
+func (n *GameNetwork) AddPeer(remoteID network.NetworkID, addr string) {
+ n.peer = remoteID
+ n.server.AddPeer(remoteID, addr)
+}
+
+func (n *GameNetwork) AddReceiveFunction(f network.NetworkMessageReceiveFunc) {
+ n.server.OnReceiveFn = f
+}
diff --git a/pkg/ui/views/api.go b/pkg/ui/views/api.go
index ba202c8..14f4cf2 100644
--- a/pkg/ui/views/api.go
+++ b/pkg/ui/views/api.go
@@ -38,7 +38,7 @@ func getUserID() (int, error) {
return -1, err
}
- claims, err := auth.ValidateJWT("Bearer " + token)
+ claims, err := auth.ValidateJWT(token)
if err != nil {
return -1, err
}
@@ -55,7 +55,7 @@ func sendAPIRequest(method, url string, payload []byte, authorization string) (*
}
req.Header.Set("Content-Type", "application/json")
- req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", authorization))
+ req.Header.Add("Authorization", authorization)
client := &http.Client{}
return client.Do(req)
diff --git a/pkg/ui/views/game.go b/pkg/ui/views/game.go
index 7075f8b..a97d068 100644
--- a/pkg/ui/views/game.go
+++ b/pkg/ui/views/game.go
@@ -5,7 +5,6 @@ import (
"strings"
"github.com/boozec/rahanna/internal/api/database"
- "github.com/boozec/rahanna/internal/network"
"github.com/boozec/rahanna/pkg/ui/multiplayer"
"github.com/charmbracelet/bubbles/list"
"github.com/charmbracelet/bubbles/textinput"
@@ -25,7 +24,6 @@ type GameModel struct {
keys gameKeyMap
// Game state
- peer string
currentGameID int
game *database.Game
network *multiplayer.GameNetwork
@@ -36,7 +34,7 @@ type GameModel struct {
}
// NewGameModel creates a new GameModel.
-func NewGameModel(width, height int, peer string, currentGameID int, network *multiplayer.GameNetwork) GameModel {
+func NewGameModel(width, height int, currentGameID int, network *multiplayer.GameNetwork) GameModel {
listDelegate := list.NewDefaultDelegate()
listDelegate.ShowDescription = false
listDelegate.Styles.SelectedTitle = lipgloss.NewStyle().
@@ -55,7 +53,6 @@ func NewGameModel(width, height int, peer string, currentGameID int, network *mu
width: width,
height: height,
keys: defaultGameKeyMap,
- peer: peer,
currentGameID: currentGameID,
network: network,
chessGame: chess.NewGame(chess.UseNotation(chess.UCINotation{})),
@@ -97,10 +94,10 @@ func (m GameModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
cmds = append(cmds, cmd, m.updateMovesListCmd())
case EndGameMsg:
if msg.abandoned {
- if m.peer == "peer-2" {
- m.game.Outcome = "1-0"
+ if m.network.Me() == "peer-1" {
+ m.game.Outcome = string(chess.WhiteWon)
} else {
- m.game.Outcome = "0-1"
+ m.game.Outcome = string(chess.BlackWon)
}
m, cmd = m.handleDatabaseGameMsg(*m.game)
cmds = append(cmds, cmd)
@@ -123,7 +120,7 @@ func (m GameModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.err = err
} else {
m.turn++
- m.network.Server.Send(network.NetworkID(m.peer), []byte(moveStr))
+ m.network.Send([]byte(moveStr))
m.err = nil
}
cmds = append(cmds, m.getMoves(), m.updateMovesListCmd())
@@ -172,17 +169,17 @@ func (m GameModel) View() string {
} else {
var outcome string
switch m.game.Outcome {
- case "1-0":
+ case string(chess.WhiteWon):
outcome = "White won"
- if m.peer == "peer-2" {
+ if m.network.Me() == "peer-1" {
outcome += " (YOU)"
}
- case "0-1":
+ case string(chess.BlackWon):
outcome = "Black won"
- if m.peer == "peer-1" {
+ if m.network.Me() == "peer-2" {
outcome += " (YOU)"
}
- case "1/2-1/2":
+ case string(chess.Draw):
outcome = "Draw"
default:
outcome = "NoOutcome"
diff --git a/pkg/ui/views/game_api.go b/pkg/ui/views/game_api.go
index 34ba1f3..485df41 100644
--- a/pkg/ui/views/game_api.go
+++ b/pkg/ui/views/game_api.go
@@ -12,11 +12,6 @@ import (
func (m GameModel) handleDatabaseGameMsg(msg database.Game) (GameModel, tea.Cmd) {
m.game = &msg
- if m.peer == "peer-2" {
- m.network.Peer = msg.IP2
- } else {
- m.network.Peer = msg.IP1
- }
var cmd tea.Cmd
@@ -52,15 +47,15 @@ func (m *GameModel) getGame() tea.Cmd {
}
// Establish peer connection
- if m.peer == "peer-2" {
+ if m.network.Me() == "peer-1" {
if game.IP2 != "" {
remote := game.IP2
- go m.network.Server.AddPeer("peer-2", remote)
+ go m.network.AddPeer("peer-2", remote)
}
} else {
if game.IP1 != "" {
remote := game.IP1
- go m.network.Server.AddPeer("peer-1", remote)
+ go m.network.AddPeer("peer-1", remote)
}
}
diff --git a/pkg/ui/views/game_keymap.go b/pkg/ui/views/game_keymap.go
index 29881c8..5c65a57 100644
--- a/pkg/ui/views/game_keymap.go
+++ b/pkg/ui/views/game_keymap.go
@@ -3,10 +3,10 @@ package views
import (
"fmt"
- "github.com/boozec/rahanna/internal/network"
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
+ "github.com/notnil/chess"
)
// gameKeyMap defines the key bindings for the game view.
@@ -36,13 +36,13 @@ func (m GameModel) handleKeyMsg(msg tea.KeyMsg) (GameModel, tea.Cmd) {
switch {
case key.Matches(msg, m.keys.Abandon):
var outcome string
- if m.peer == "peer-2" {
- outcome = "0-1"
+ if m.network.Me() == "peer-1" {
+ outcome = string(chess.BlackWon)
} else {
- outcome = "1-0"
+ outcome = string(chess.WhiteWon)
}
- m.network.Server.Send(network.NetworkID(m.peer), []byte("🏳️"))
+ m.network.Send([]byte("🏳️"))
return m, m.endGame(outcome)
case key.Matches(msg, m.keys.Quit):
return m, SwitchModelCmd(NewPlayModel(m.width, m.height))
diff --git a/pkg/ui/views/game_moves.go b/pkg/ui/views/game_moves.go
index 4ce3796..eeee9e1 100644
--- a/pkg/ui/views/game_moves.go
+++ b/pkg/ui/views/game_moves.go
@@ -24,10 +24,10 @@ func (i item) Description() string { return "" }
func (i item) FilterValue() string { return i.title }
func (m *GameModel) getMoves() tea.Cmd {
- m.network.Server.OnReceiveFn = func(msg network.Message) {
+ m.network.AddReceiveFunction(func(msg network.Message) {
payload := string(msg.Payload)
m.incomingMoves <- payload
- }
+ })
return func() tea.Msg {
move := <-m.incomingMoves
diff --git a/pkg/ui/views/game_util.go b/pkg/ui/views/game_util.go
index 7d83eee..eb904c5 100644
--- a/pkg/ui/views/game_util.go
+++ b/pkg/ui/views/game_util.go
@@ -24,5 +24,5 @@ func (m GameModel) buildWindowContent(content string, formWidth int) string {
}
func (m GameModel) isMyTurn() bool {
- return m.turn%2 == 0 && m.peer == "peer-2" || m.turn%2 == 1 && m.peer == "peer-1"
+ return m.turn%2 == 0 && m.network.Me() == "peer-1" || m.turn%2 == 1 && m.network.Me() == "peer-2"
}
diff --git a/pkg/ui/views/play.go b/pkg/ui/views/play.go
index ed5f1e1..20e2ebb 100644
--- a/pkg/ui/views/play.go
+++ b/pkg/ui/views/play.go
@@ -89,7 +89,7 @@ func (m PlayModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.userID, m.err = getUserID()
return m.handleGamesResponse(msg)
case StartGameMsg:
- return m, SwitchModelCmd(NewGameModel(m.width, m.height+1, "peer-2", m.currentGameId, m.network))
+ return m, SwitchModelCmd(NewGameModel(m.width, m.height+1, m.currentGameId, m.network))
case error:
return m.handleError(msg)
}
diff --git a/pkg/ui/views/play_api.go b/pkg/ui/views/play_api.go
index 7119ffa..40f26a8 100644
--- a/pkg/ui/views/play_api.go
+++ b/pkg/ui/views/play_api.go
@@ -68,11 +68,9 @@ func (m *PlayModel) handleGameResponse(msg database.Game) (tea.Model, tea.Cmd) {
localPort, _ := strconv.ParseInt(ip[1], 10, 32)
logger, _ := logger.GetLogger()
- network := multiplayer.NewGameNetwork("peer-2", fmt.Sprintf("%s:%d", localIP, localPort), func() error {
- return nil
- }, logger)
+ network := multiplayer.NewGameNetwork("peer-2", fmt.Sprintf("%s:%d", localIP, localPort), network.DefaultHandshake, logger)
- return m, SwitchModelCmd(NewGameModel(m.width, m.height+1, "peer-1", m.game.ID, network))
+ return m, SwitchModelCmd(NewGameModel(m.width, m.height+1, m.game.ID, network))
}
return m, nil
}