Compare commits

...

4 Commits

Author SHA1 Message Date
Mumble
2724f4b6da
Merge 45eeba42f7 into 8e20875453 2024-09-20 09:50:08 -07:00
Frederico Santos
8e20875453
feat: Add admin Discord link endpoint (#49)
* feat: Add admin Discord link endpoint

* feat: Add Discord Guild ID flag to server configuration

* feat: Add logging for Discord ID addition in admin Discord link endpoint

* chore: Update variable name for Discord guild ID in account package

* chore: Add logging for Discord ID addition in admin Discord link endpoint

* chore: Add admin Discord link endpoint

* chore: Add logging for Discord ID addition in admin Discord link endpoint

* chore: Remove unnecessary code in handleAdminDiscordLink function

* chore: Update logging format in handleAdminDiscordLink function

* chore: Refactor handleAdminDiscordLink function for improved logging

* chore: Update Discord Bot Token and Discord Guild ID flags in server configuration

* chore: Refactor handleAccountInfo function for improved readability and error handling

* chore: Update server configuration flags for Discord Bot Token and Guild ID

* Refactor handleAdminDiscordLink function for improved error handling and logging

* feat: Add "Helper" role to Discord admin check for enhanced access control
2024-09-13 22:32:49 -04:00
Pancakes
a3fcec4a4c
Don't use Sprintf in Errorf for no reason 2024-09-10 21:39:07 -04:00
frutescens
45eeba42f7 No more longblob 2024-08-25 15:02:04 -07:00
11 changed files with 180 additions and 42 deletions

View File

@ -32,7 +32,7 @@ func ChangePW(uuid []byte, password string) error {
salt := make([]byte, ArgonSaltSize)
_, err := rand.Read(salt)
if err != nil {
return fmt.Errorf(fmt.Sprintf("failed to generate salt: %s", err))
return fmt.Errorf("failed to generate salt: %s", err)
}
err = db.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt)

View File

@ -22,12 +22,17 @@ import (
"errors"
"net/http"
"net/url"
"github.com/bwmarrin/discordgo"
)
var (
DiscordClientID string
DiscordClientSecret string
DiscordCallbackURL string
DiscordSession *discordgo.Session
DiscordGuildID string
)
func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
@ -36,7 +41,6 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro
http.Redirect(w, r, GameURL, http.StatusSeeOther)
return "", errors.New("code is empty")
}
discordId, err := RetrieveDiscordId(code)
if err != nil {
http.Redirect(w, r, GameURL, http.StatusSeeOther)
@ -106,3 +110,34 @@ func RetrieveDiscordId(code string) (string, error) {
return user.Id, nil
}
func IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
// fetch all roles from discord
roles, err := DiscordSession.GuildRoles(discordGuildID)
if err != nil {
return false, err
}
// fetch all roles from user
userRoles, err := DiscordSession.GuildMember(discordGuildID, discordId)
if err != nil {
return false, err
}
// check if user has a "Dev" or a "Division Heads" role
var hasRole bool
for _, role := range userRoles.Roles {
for _, guildRole := range roles {
if role == guildRole.ID && (guildRole.Name == "Dev" || guildRole.Name == "Division Heads" || guildRole.Name == "Helper") {
hasRole = true
break
}
}
}
if !hasRole {
return false, nil
}
return true, nil
}

View File

@ -26,16 +26,18 @@ type InfoResponse struct {
DiscordId string `json:"discordId"`
GoogleId string `json:"googleId"`
LastSessionSlot int `json:"lastSessionSlot"`
HasAdminRole bool `json:"hasAdminRole"`
}
// /account/info - get account info
func Info(username string, discordId string, googleId string, uuid []byte) (InfoResponse, error) {
func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
slot, _ := db.GetLatestSessionSaveDataSlot(uuid)
response := InfoResponse{
Username: username,
LastSessionSlot: slot,
DiscordId: discordId,
GoogleId: googleId,
HasAdminRole: hasAdminRole,
}
return response, nil
}

View File

@ -20,6 +20,7 @@ package account
import (
"crypto/rand"
"fmt"
"github.com/pagefaultgames/rogueserver/db"
)
@ -42,7 +43,7 @@ func Register(username, password string) error {
salt := make([]byte, ArgonSaltSize)
_, err = rand.Read(salt)
if err != nil {
return fmt.Errorf(fmt.Sprintf("failed to generate salt: %s", err))
return fmt.Errorf("failed to generate salt: %s", err)
}
err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)

View File

@ -66,6 +66,10 @@ func Init(mux *http.ServeMux) error {
// auth
mux.HandleFunc("/auth/{provider}/callback", handleProviderCallback)
mux.HandleFunc("/auth/{provider}/logout", handleProviderLogout)
// admin
mux.HandleFunc("POST /admin/account/discord-link", handleAdminDiscordLink)
return nil
}

View File

@ -23,6 +23,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"strings"
@ -68,7 +69,13 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
return
}
}
response, err := account.Info(username, discordId, googleId, uuid)
var hasAdminRole bool
if discordId != "" {
hasAdminRole, _ = account.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
}
response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -660,3 +667,39 @@ func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
}
w.WriteHeader(http.StatusOK)
}
func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest)
return
}
uuid, err := uuidFromRequest(r)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
}
err = db.AddDiscordIdByUsername(r.Form.Get("discordId"), r.Form.Get("username"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
log.Printf("%s: %s added discord id %s to username %s", r.URL.Path, userDiscordId, r.Form.Get("discordId"), r.Form.Get("username"))
w.WriteHeader(http.StatusOK)
}

View File

@ -116,6 +116,34 @@ func FetchGoogleIdByUsername(username string) (string, error) {
return googleId.String, nil
}
func FetchDiscordIdByUUID(uuid []byte) (string, error) {
var discordId sql.NullString
err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId)
if err != nil {
return "", err
}
if !discordId.Valid {
return "", nil
}
return discordId.String, nil
}
func FetchGoogleIdByUUID(uuid []byte) (string, error) {
var googleId sql.NullString
err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId)
if err != nil {
return "", err
}
if !googleId.Valid {
return "", nil
}
return googleId.String, nil
}
func FetchUsernameBySessionToken(token []byte) (string, error) {
var username string
err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username)

View File

@ -83,9 +83,9 @@ func setupDb(tx *sql.Tx) error {
`CREATE TABLE IF NOT EXISTS accountDailyRuns (uuid BINARY(16) NOT NULL, date DATE NOT NULL, score INT(11) NOT NULL DEFAULT 0, wave INT(11) NOT NULL DEFAULT 0, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (uuid, date), CONSTRAINT accountDailyRuns_ibfk_1 FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT accountDailyRuns_ibfk_2 FOREIGN KEY (date) REFERENCES dailyRuns (date) ON DELETE NO ACTION ON UPDATE NO ACTION)`,
`CREATE INDEX IF NOT EXISTS accountDailyRunsByDate ON accountDailyRuns (date)`,
`CREATE TABLE IF NOT EXISTS systemSaveData (uuid BINARY(16) PRIMARY KEY, data LONGBLOB, timestamp TIMESTAMP, FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)`,
`CREATE TABLE IF NOT EXISTS systemSaveData (uuid BINARY(16) PRIMARY KEY, data MEDIUMBLOB, timestamp TIMESTAMP, FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)`,
`CREATE TABLE IF NOT EXISTS sessionSaveData (uuid BINARY(16), slot TINYINT, data LONGBLOB, timestamp TIMESTAMP, PRIMARY KEY (uuid, slot), FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)`,
`CREATE TABLE IF NOT EXISTS sessionSaveData (uuid BINARY(16), slot TINYINT, data BLOB, timestamp TIMESTAMP, PRIMARY KEY (uuid, slot), FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)`,
// ----------------------------------
// MIGRATION 001

6
go.mod
View File

@ -13,4 +13,8 @@ require (
github.com/klauspost/compress v1.17.9
)
require golang.org/x/sys v0.19.0 // indirect
require (
github.com/bwmarrin/discordgo v0.28.1 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
golang.org/x/sys v0.19.0 // indirect
)

10
go.sum
View File

@ -1,12 +1,22 @@
github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4=
github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@ -19,66 +19,69 @@ package main
import (
"encoding/gob"
"flag"
"log"
"net"
"net/http"
"os"
"strconv"
"github.com/bwmarrin/discordgo"
"github.com/pagefaultgames/rogueserver/api"
"github.com/pagefaultgames/rogueserver/api/account"
"github.com/pagefaultgames/rogueserver/db"
)
func main() {
// flag stuff
debug := flag.Bool("debug", false, "use debug mode")
// env stuff
debug, _ := strconv.ParseBool(os.Getenv("debug"))
proto := flag.String("proto", "tcp", "protocol for api to use (tcp, unix)")
addr := flag.String("addr", "0.0.0.0:8001", "network address for api to listen on")
tlscert := flag.String("tlscert", "", "tls certificate path")
tlskey := flag.String("tlskey", "", "tls key path")
proto := getEnv("proto", "tcp")
addr := getEnv("addr", "0.0.0.0:8001")
tlscert := getEnv("tlscert", "")
tlskey := getEnv("tlskey", "")
dbuser := flag.String("dbuser", "pokerogue", "database username")
dbpass := flag.String("dbpass", "pokerogue", "database password")
dbproto := flag.String("dbproto", "tcp", "protocol for database connection")
dbaddr := flag.String("dbaddr", "localhost", "database address")
dbname := flag.String("dbname", "pokeroguedb", "database name")
dbuser := getEnv("dbuser", "pokerogue")
dbpass := getEnv("dbpass", "pokerogue")
dbproto := getEnv("dbproto", "tcp")
dbaddr := getEnv("dbaddr", "localhost")
dbname := getEnv("dbname", "pokeroguedb")
discordclientid := flag.String("discordclientid", "dcid", "Discord Oauth2 Client ID")
discordsecretid := flag.String("discordsecretid", "dsid", "Discord Oauth2 Secret ID")
discordclientid := getEnv("discordclientid", "")
discordsecretid := getEnv("discordsecretid", "")
googleclientid := flag.String("googleclientid", "gcid", "Google Oauth2 Client ID")
googlesecretid := flag.String("googlesecretid", "gsid", "Google Oauth2 Secret ID")
googleclientid := getEnv("googleclientid", "")
googlesecretid := getEnv("googlesecretid", "")
callbackurl := flag.String("callbackurl", "http://localhost:8001/", "Callback URL for Oauth2 Client")
callbackurl := getEnv("callbackurl", "http://localhost:8001/")
gameurl := flag.String("gameurl", "https://pokerogue.net", "URL for game server")
gameurl := getEnv("gameurl", "https://pokerogue.net")
flag.Parse()
discordbottoken := getEnv("discordbottoken", "")
discordguildid := getEnv("discordguildid", "")
account.GameURL = *gameurl
account.GameURL = gameurl
account.DiscordClientID = *discordclientid
account.DiscordClientSecret = *discordsecretid
account.DiscordCallbackURL = *callbackurl + "/auth/discord/callback"
account.GoogleClientID = *googleclientid
account.GoogleClientSecret = *googlesecretid
account.GoogleCallbackURL = *callbackurl + "/auth/google/callback"
account.DiscordClientID = discordclientid
account.DiscordClientSecret = discordsecretid
account.DiscordCallbackURL = callbackurl + "/auth/discord/callback"
account.GoogleClientID = googleclientid
account.GoogleClientSecret = googlesecretid
account.GoogleCallbackURL = callbackurl + "/auth/google/callback"
account.DiscordSession, _ = discordgo.New("Bot " + discordbottoken)
account.DiscordGuildID = discordguildid
// register gob types
gob.Register([]interface{}{})
gob.Register(map[string]interface{}{})
// get database connection
err := db.Init(*dbuser, *dbpass, *dbproto, *dbaddr, *dbname)
err := db.Init(dbuser, dbpass, dbproto, dbaddr, dbname)
if err != nil {
log.Fatalf("failed to initialize database: %s", err)
}
// create listener
listener, err := createListener(*proto, *addr)
listener, err := createListener(proto, addr)
if err != nil {
log.Fatalf("failed to create net listener: %s", err)
}
@ -92,14 +95,14 @@ func main() {
// start web server
handler := prodHandler(mux, gameurl)
if *debug {
if debug {
handler = debugHandler(mux)
}
if *tlscert == "" {
if tlscert == "" {
err = http.Serve(listener, handler)
} else {
err = http.ServeTLS(listener, handler, *tlscert, *tlskey)
err = http.ServeTLS(listener, handler, tlscert, tlskey)
}
if err != nil {
log.Fatalf("failed to create http server or server errored: %s", err)
@ -126,11 +129,11 @@ func createListener(proto, addr string) (net.Listener, error) {
return listener, nil
}
func prodHandler(router *http.ServeMux, clienturl *string) http.Handler {
func prodHandler(router *http.ServeMux, clienturl string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
w.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST")
w.Header().Set("Access-Control-Allow-Origin", *clienturl)
w.Header().Set("Access-Control-Allow-Origin", clienturl)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
@ -155,3 +158,11 @@ func debugHandler(router *http.ServeMux) http.Handler {
router.ServeHTTP(w, r)
})
}
func getEnv(key string, defaultValue string) string {
if value, ok := os.LookupEnv(key); ok {
return value
}
return defaultValue
}