mirror of
https://github.com/pagefaultgames/rogueserver.git
synced 2025-07-27 10:42:29 +02:00
Compare commits
4 Commits
de17208bde
...
2724f4b6da
Author | SHA1 | Date | |
---|---|---|---|
|
2724f4b6da | ||
|
8e20875453 | ||
|
a3fcec4a4c | ||
|
45eeba42f7 |
@ -32,7 +32,7 @@ func ChangePW(uuid []byte, password string) error {
|
||||
salt := make([]byte, ArgonSaltSize)
|
||||
_, err := rand.Read(salt)
|
||||
if err != nil {
|
||||
return fmt.Errorf(fmt.Sprintf("failed to generate salt: %s", err))
|
||||
return fmt.Errorf("failed to generate salt: %s", err)
|
||||
}
|
||||
|
||||
err = db.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt)
|
||||
|
@ -22,12 +22,17 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
var (
|
||||
DiscordClientID string
|
||||
DiscordClientSecret string
|
||||
DiscordCallbackURL string
|
||||
|
||||
DiscordSession *discordgo.Session
|
||||
DiscordGuildID string
|
||||
)
|
||||
|
||||
func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
@ -36,7 +41,6 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro
|
||||
http.Redirect(w, r, GameURL, http.StatusSeeOther)
|
||||
return "", errors.New("code is empty")
|
||||
}
|
||||
|
||||
discordId, err := RetrieveDiscordId(code)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, GameURL, http.StatusSeeOther)
|
||||
@ -106,3 +110,34 @@ func RetrieveDiscordId(code string) (string, error) {
|
||||
|
||||
return user.Id, nil
|
||||
}
|
||||
|
||||
func IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
|
||||
// fetch all roles from discord
|
||||
roles, err := DiscordSession.GuildRoles(discordGuildID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// fetch all roles from user
|
||||
userRoles, err := DiscordSession.GuildMember(discordGuildID, discordId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// check if user has a "Dev" or a "Division Heads" role
|
||||
var hasRole bool
|
||||
for _, role := range userRoles.Roles {
|
||||
for _, guildRole := range roles {
|
||||
if role == guildRole.ID && (guildRole.Name == "Dev" || guildRole.Name == "Division Heads" || guildRole.Name == "Helper") {
|
||||
hasRole = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasRole {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
@ -26,16 +26,18 @@ type InfoResponse struct {
|
||||
DiscordId string `json:"discordId"`
|
||||
GoogleId string `json:"googleId"`
|
||||
LastSessionSlot int `json:"lastSessionSlot"`
|
||||
HasAdminRole bool `json:"hasAdminRole"`
|
||||
}
|
||||
|
||||
// /account/info - get account info
|
||||
func Info(username string, discordId string, googleId string, uuid []byte) (InfoResponse, error) {
|
||||
func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
|
||||
slot, _ := db.GetLatestSessionSaveDataSlot(uuid)
|
||||
response := InfoResponse{
|
||||
Username: username,
|
||||
LastSessionSlot: slot,
|
||||
DiscordId: discordId,
|
||||
GoogleId: googleId,
|
||||
HasAdminRole: hasAdminRole,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ package account
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
)
|
||||
|
||||
@ -42,7 +43,7 @@ func Register(username, password string) error {
|
||||
salt := make([]byte, ArgonSaltSize)
|
||||
_, err = rand.Read(salt)
|
||||
if err != nil {
|
||||
return fmt.Errorf(fmt.Sprintf("failed to generate salt: %s", err))
|
||||
return fmt.Errorf("failed to generate salt: %s", err)
|
||||
}
|
||||
|
||||
err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
|
||||
|
@ -66,6 +66,10 @@ func Init(mux *http.ServeMux) error {
|
||||
// auth
|
||||
mux.HandleFunc("/auth/{provider}/callback", handleProviderCallback)
|
||||
mux.HandleFunc("/auth/{provider}/logout", handleProviderLogout)
|
||||
|
||||
// admin
|
||||
mux.HandleFunc("POST /admin/account/discord-link", handleAdminDiscordLink)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -68,7 +69,13 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
response, err := account.Info(username, discordId, googleId, uuid)
|
||||
|
||||
var hasAdminRole bool
|
||||
if discordId != "" {
|
||||
hasAdminRole, _ = account.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
|
||||
}
|
||||
|
||||
response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -660,3 +667,39 @@ func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
uuid, err := uuidFromRequest(r)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
if !hasRole || err != nil {
|
||||
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
err = db.AddDiscordIdByUsername(r.Form.Get("discordId"), r.Form.Get("username"))
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("%s: %s added discord id %s to username %s", r.URL.Path, userDiscordId, r.Form.Get("discordId"), r.Form.Get("username"))
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
@ -116,6 +116,34 @@ func FetchGoogleIdByUsername(username string) (string, error) {
|
||||
return googleId.String, nil
|
||||
}
|
||||
|
||||
func FetchDiscordIdByUUID(uuid []byte) (string, error) {
|
||||
var discordId sql.NullString
|
||||
err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !discordId.Valid {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return discordId.String, nil
|
||||
}
|
||||
|
||||
func FetchGoogleIdByUUID(uuid []byte) (string, error) {
|
||||
var googleId sql.NullString
|
||||
err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !googleId.Valid {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return googleId.String, nil
|
||||
}
|
||||
|
||||
func FetchUsernameBySessionToken(token []byte) (string, error) {
|
||||
var username string
|
||||
err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username)
|
||||
|
4
db/db.go
4
db/db.go
@ -83,9 +83,9 @@ func setupDb(tx *sql.Tx) error {
|
||||
`CREATE TABLE IF NOT EXISTS accountDailyRuns (uuid BINARY(16) NOT NULL, date DATE NOT NULL, score INT(11) NOT NULL DEFAULT 0, wave INT(11) NOT NULL DEFAULT 0, timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (uuid, date), CONSTRAINT accountDailyRuns_ibfk_1 FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT accountDailyRuns_ibfk_2 FOREIGN KEY (date) REFERENCES dailyRuns (date) ON DELETE NO ACTION ON UPDATE NO ACTION)`,
|
||||
`CREATE INDEX IF NOT EXISTS accountDailyRunsByDate ON accountDailyRuns (date)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS systemSaveData (uuid BINARY(16) PRIMARY KEY, data LONGBLOB, timestamp TIMESTAMP, FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)`,
|
||||
`CREATE TABLE IF NOT EXISTS systemSaveData (uuid BINARY(16) PRIMARY KEY, data MEDIUMBLOB, timestamp TIMESTAMP, FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS sessionSaveData (uuid BINARY(16), slot TINYINT, data LONGBLOB, timestamp TIMESTAMP, PRIMARY KEY (uuid, slot), FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)`,
|
||||
`CREATE TABLE IF NOT EXISTS sessionSaveData (uuid BINARY(16), slot TINYINT, data BLOB, timestamp TIMESTAMP, PRIMARY KEY (uuid, slot), FOREIGN KEY (uuid) REFERENCES accounts (uuid) ON DELETE CASCADE ON UPDATE CASCADE)`,
|
||||
|
||||
// ----------------------------------
|
||||
// MIGRATION 001
|
||||
|
6
go.mod
6
go.mod
@ -13,4 +13,8 @@ require (
|
||||
github.com/klauspost/compress v1.17.9
|
||||
)
|
||||
|
||||
require golang.org/x/sys v0.19.0 // indirect
|
||||
require (
|
||||
github.com/bwmarrin/discordgo v0.28.1 // indirect
|
||||
github.com/gorilla/websocket v1.4.2 // indirect
|
||||
golang.org/x/sys v0.19.0 // indirect
|
||||
)
|
||||
|
10
go.sum
10
go.sum
@ -1,12 +1,22 @@
|
||||
github.com/bwmarrin/discordgo v0.28.1 h1:gXsuo2GBO7NbR6uqmrrBDplPUx2T3nzu775q/Rd1aG4=
|
||||
github.com/bwmarrin/discordgo v0.28.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
|
||||
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
|
||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
@ -19,66 +19,69 @@ package main
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"flag"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/pagefaultgames/rogueserver/api"
|
||||
"github.com/pagefaultgames/rogueserver/api/account"
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// flag stuff
|
||||
debug := flag.Bool("debug", false, "use debug mode")
|
||||
// env stuff
|
||||
debug, _ := strconv.ParseBool(os.Getenv("debug"))
|
||||
|
||||
proto := flag.String("proto", "tcp", "protocol for api to use (tcp, unix)")
|
||||
addr := flag.String("addr", "0.0.0.0:8001", "network address for api to listen on")
|
||||
tlscert := flag.String("tlscert", "", "tls certificate path")
|
||||
tlskey := flag.String("tlskey", "", "tls key path")
|
||||
proto := getEnv("proto", "tcp")
|
||||
addr := getEnv("addr", "0.0.0.0:8001")
|
||||
tlscert := getEnv("tlscert", "")
|
||||
tlskey := getEnv("tlskey", "")
|
||||
|
||||
dbuser := flag.String("dbuser", "pokerogue", "database username")
|
||||
dbpass := flag.String("dbpass", "pokerogue", "database password")
|
||||
dbproto := flag.String("dbproto", "tcp", "protocol for database connection")
|
||||
dbaddr := flag.String("dbaddr", "localhost", "database address")
|
||||
dbname := flag.String("dbname", "pokeroguedb", "database name")
|
||||
dbuser := getEnv("dbuser", "pokerogue")
|
||||
dbpass := getEnv("dbpass", "pokerogue")
|
||||
dbproto := getEnv("dbproto", "tcp")
|
||||
dbaddr := getEnv("dbaddr", "localhost")
|
||||
dbname := getEnv("dbname", "pokeroguedb")
|
||||
|
||||
discordclientid := flag.String("discordclientid", "dcid", "Discord Oauth2 Client ID")
|
||||
discordsecretid := flag.String("discordsecretid", "dsid", "Discord Oauth2 Secret ID")
|
||||
discordclientid := getEnv("discordclientid", "")
|
||||
discordsecretid := getEnv("discordsecretid", "")
|
||||
|
||||
googleclientid := flag.String("googleclientid", "gcid", "Google Oauth2 Client ID")
|
||||
googlesecretid := flag.String("googlesecretid", "gsid", "Google Oauth2 Secret ID")
|
||||
googleclientid := getEnv("googleclientid", "")
|
||||
googlesecretid := getEnv("googlesecretid", "")
|
||||
|
||||
callbackurl := flag.String("callbackurl", "http://localhost:8001/", "Callback URL for Oauth2 Client")
|
||||
callbackurl := getEnv("callbackurl", "http://localhost:8001/")
|
||||
|
||||
gameurl := flag.String("gameurl", "https://pokerogue.net", "URL for game server")
|
||||
gameurl := getEnv("gameurl", "https://pokerogue.net")
|
||||
|
||||
flag.Parse()
|
||||
discordbottoken := getEnv("discordbottoken", "")
|
||||
discordguildid := getEnv("discordguildid", "")
|
||||
|
||||
account.GameURL = *gameurl
|
||||
account.GameURL = gameurl
|
||||
|
||||
account.DiscordClientID = *discordclientid
|
||||
account.DiscordClientSecret = *discordsecretid
|
||||
account.DiscordCallbackURL = *callbackurl + "/auth/discord/callback"
|
||||
|
||||
account.GoogleClientID = *googleclientid
|
||||
account.GoogleClientSecret = *googlesecretid
|
||||
account.GoogleCallbackURL = *callbackurl + "/auth/google/callback"
|
||||
account.DiscordClientID = discordclientid
|
||||
account.DiscordClientSecret = discordsecretid
|
||||
account.DiscordCallbackURL = callbackurl + "/auth/discord/callback"
|
||||
|
||||
account.GoogleClientID = googleclientid
|
||||
account.GoogleClientSecret = googlesecretid
|
||||
account.GoogleCallbackURL = callbackurl + "/auth/google/callback"
|
||||
account.DiscordSession, _ = discordgo.New("Bot " + discordbottoken)
|
||||
account.DiscordGuildID = discordguildid
|
||||
// register gob types
|
||||
gob.Register([]interface{}{})
|
||||
gob.Register(map[string]interface{}{})
|
||||
|
||||
// get database connection
|
||||
err := db.Init(*dbuser, *dbpass, *dbproto, *dbaddr, *dbname)
|
||||
err := db.Init(dbuser, dbpass, dbproto, dbaddr, dbname)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize database: %s", err)
|
||||
}
|
||||
|
||||
// create listener
|
||||
listener, err := createListener(*proto, *addr)
|
||||
listener, err := createListener(proto, addr)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create net listener: %s", err)
|
||||
}
|
||||
@ -92,14 +95,14 @@ func main() {
|
||||
|
||||
// start web server
|
||||
handler := prodHandler(mux, gameurl)
|
||||
if *debug {
|
||||
if debug {
|
||||
handler = debugHandler(mux)
|
||||
}
|
||||
|
||||
if *tlscert == "" {
|
||||
if tlscert == "" {
|
||||
err = http.Serve(listener, handler)
|
||||
} else {
|
||||
err = http.ServeTLS(listener, handler, *tlscert, *tlskey)
|
||||
err = http.ServeTLS(listener, handler, tlscert, tlskey)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create http server or server errored: %s", err)
|
||||
@ -126,11 +129,11 @@ func createListener(proto, addr string) (net.Listener, error) {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func prodHandler(router *http.ServeMux, clienturl *string) http.Handler {
|
||||
func prodHandler(router *http.ServeMux, clienturl string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST")
|
||||
w.Header().Set("Access-Control-Allow-Origin", *clienturl)
|
||||
w.Header().Set("Access-Control-Allow-Origin", clienturl)
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@ -155,3 +158,11 @@ func debugHandler(router *http.ServeMux) http.Handler {
|
||||
router.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func getEnv(key string, defaultValue string) string {
|
||||
if value, ok := os.LookupEnv(key); ok {
|
||||
return value
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user