From 7c0a777737ce9eb381f73197506fd8c2af281446 Mon Sep 17 00:00:00 2001 From: Chris Sexton Date: Wed, 3 Aug 2022 21:09:04 -0400 Subject: [PATCH] twitch: add reauthentication - refactored secrets to be in config - added missing format string to bot --- bot/bot.go | 2 +- config/config.go | 28 ++++++++- plugins/secrets/secrets.go | 31 +--------- plugins/twitch/twitch.go | 116 +++++++++++++++++++++++++++++++------ 4 files changed, 128 insertions(+), 49 deletions(-) diff --git a/bot/bot.go b/bot/bot.go index 02936f3..9640caf 100644 --- a/bot/bot.go +++ b/bot/bot.go @@ -265,7 +265,7 @@ func (b *bot) CheckAdmin(ID string) bool { log.Info().Interface("admins", admins).Msgf("Checking admin for %s", ID) for _, u := range admins { if ID == u { - log.Info().Msg("%s admin check: passed") + log.Info().Msgf("%s admin check: passed", u) return true } } diff --git a/config/config.go b/config/config.go index a9d58a2..eeecd06 100644 --- a/config/config.go +++ b/config/config.go @@ -201,6 +201,32 @@ func (c *Config) SecretKeys() []string { return keys } +func (c *Config) setSecret(key, value string) error { + q := `insert into secrets (key,value) values (?, ?) + on conflict(key) do update set value=?;` + _, err := c.Exec(q, key, value, value) + if err != nil { + log.Fatal().Err(err).Msgf("secret") + return err + } + return c.RefreshSecrets() +} + +// RegisterSecret creates a new secret +func (c *Config) RegisterSecret(key, value string) error { + return c.setSecret(key, value) +} + +// RemoveSecret deregisters a secret +func (c *Config) RemoveSecret(key string) error { + q := `delete from secrets where key=?` + _, err := c.Exec(q, key) + if err != nil { + return err + } + return c.RefreshSecrets() +} + func (c *Config) SetMap(key string, values map[string]string) error { b, err := json.Marshal(values) if err != nil { @@ -256,7 +282,7 @@ func ReadConfig(dbpath string) *Config { value string, primary key (key) );`); err != nil { - log.Fatal().Err(err).Msgf("failed to initialize config") + log.Fatal().Err(err).Msgf("failed to initialize secrets") } if err := c.RefreshSecrets(); err != nil { diff --git a/plugins/secrets/secrets.go b/plugins/secrets/secrets.go index 9e10954..8ceb187 100644 --- a/plugins/secrets/secrets.go +++ b/plugins/secrets/secrets.go @@ -46,33 +46,6 @@ func (p *SecretsPlugin) registerWeb() { p.b.RegisterWebName(r, "/secrets", "Secrets") } -func (p *SecretsPlugin) registerSecret(key, value string) error { - q := `insert into secrets (key, value) values (?, ?)` - _, err := p.db.Exec(q, key, value) - if err != nil { - return err - } - return p.c.RefreshSecrets() -} - -func (p *SecretsPlugin) removeSecret(key string) error { - q := `delete from secrets where key=?` - _, err := p.db.Exec(q, key) - if err != nil { - return err - } - return p.c.RefreshSecrets() -} - -func (p *SecretsPlugin) updateSecret(key, value string) error { - q := `update secrets set value=? where key=?` - _, err := p.db.Exec(q, value, key) - if err != nil { - return err - } - return p.c.RefreshSecrets() -} - func mkCheckError(w http.ResponseWriter) func(error) bool { return func(err error) bool { if err != nil { @@ -130,7 +103,7 @@ func (p *SecretsPlugin) handleRegister(w http.ResponseWriter, r *http.Request) { return } log.Debug().Msgf("Secret: %s", secret) - err = p.registerSecret(secret.Key, secret.Value) + err = p.c.RegisterSecret(secret.Key, secret.Value) if checkError(err) { return } @@ -148,7 +121,7 @@ func (p *SecretsPlugin) handleRemove(w http.ResponseWriter, r *http.Request) { if checkError(err) { return } - err = p.removeSecret(secret.Key) + err = p.c.RemoveSecret(secret.Key) if checkError(err) { return } diff --git a/plugins/twitch/twitch.go b/plugins/twitch/twitch.go index 92537bb..85b8e9d 100644 --- a/plugins/twitch/twitch.go +++ b/plugins/twitch/twitch.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/go-chi/chi/v5" "io/ioutil" "net/http" "net/url" @@ -12,6 +11,8 @@ import ( "text/template" "time" + "github.com/go-chi/chi/v5" + "github.com/rs/zerolog/log" "github.com/velour/catbase/bot" "github.com/velour/catbase/bot/msg" @@ -75,9 +76,11 @@ func New(b bot.Bot) *TwitchPlugin { } } } - go p.twitchLoop(b.DefaultConnector(), ch) + go p.twitchChannelLoop(b.DefaultConnector(), ch) } + go p.twitchAuthLoop(b.DefaultConnector()) + b.Register(p, bot.Message, p.message) b.Register(p, bot.Help, p.help) p.registerWeb() @@ -128,7 +131,10 @@ func (p *TwitchPlugin) message(c bot.Connector, kind bot.Kind, message msg.Messa if users := p.config.GetArray("Twitch."+channel+".Users", []string{}); len(users) > 0 { for _, twitcherName := range users { if _, ok := p.twitchList[twitcherName]; ok { - p.checkTwitch(c, channel, p.twitchList[twitcherName], true) + err := p.checkTwitch(c, channel, p.twitchList[twitcherName], true) + if err != nil { + log.Error().Err(err).Msgf("error in checking twitch") + } } } } @@ -153,25 +159,52 @@ func (p *TwitchPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, return true } -func (p *TwitchPlugin) twitchLoop(c bot.Connector, channel string) { +func (p *TwitchPlugin) twitchAuthLoop(c bot.Connector) { + frequency := p.config.GetInt("Twitch.AuthFreq", 60*60) + cid := p.config.Get("twitch.clientid", "") + secret := p.config.Get("twitch.secret", "") + if cid == "" || secret == "" { + log.Info().Msgf("Disabling twitch autoauth.") + return + } + + log.Info().Msgf("Checking auth every %d seconds", frequency) + + if err := p.validateCredentials(); err != nil { + log.Error().Err(err).Msgf("error checking twitch validity") + } + + for { + select { + case <-time.After(time.Duration(frequency) * time.Second): + if err := p.validateCredentials(); err != nil { + log.Error().Err(err).Msgf("error checking twitch validity") + } + } + } +} + +func (p *TwitchPlugin) twitchChannelLoop(c bot.Connector, channel string) { frequency := p.config.GetInt("Twitch.Freq", 60) - if p.config.Get("twitch.clientid", "") == "" || p.config.Get("twitch.token", "") == "" { + if p.config.Get("twitch.clientid", "") == "" || p.config.Get("twitch.secret", "") == "" { log.Info().Msgf("Disabling twitch autochecking.") return } - log.Info().Msgf("Checking every %d seconds", frequency) + log.Info().Msgf("Checking channels every %d seconds", frequency) for { time.Sleep(time.Duration(frequency) * time.Second) for _, twitcherName := range p.config.GetArray("Twitch."+channel+".Users", []string{}) { - p.checkTwitch(c, channel, p.twitchList[twitcherName], false) + if err := p.checkTwitch(c, channel, p.twitchList[twitcherName], false); err != nil { + log.Error().Err(err).Msgf("error in twitch loop") + } } } } -func getRequest(url, clientID, token string) ([]byte, bool) { +func getRequest(url, clientID, token string) ([]byte, int, bool) { bearer := fmt.Sprintf("Bearer %s", token) var body []byte var resp *http.Response @@ -193,18 +226,19 @@ func getRequest(url, clientID, token string) ([]byte, bool) { if err != nil { goto errCase } - return body, true + return body, resp.StatusCode, true errCase: log.Error().Err(err) - return []byte{}, false + return []byte{}, resp.StatusCode, false } -func (p *TwitchPlugin) checkTwitch(c bot.Connector, channel string, twitcher *Twitcher, alwaysPrintStatus bool) { +func (p *TwitchPlugin) checkTwitch(c bot.Connector, channel string, twitcher *Twitcher, alwaysPrintStatus bool) error { baseURL, err := url.Parse("https://api.twitch.tv/helix/streams") if err != nil { - log.Error().Msg("Error parsing twitch stream URL") - return + err := fmt.Errorf("error parsing twitch stream URL") + log.Error().Msg(err.Error()) + return err } query := baseURL.Query() @@ -216,25 +250,28 @@ func (p *TwitchPlugin) checkTwitch(c bot.Connector, channel string, twitcher *Tw token := p.config.Get("twitch.token", "") if cid == token && cid == "" { log.Info().Msgf("Twitch plugin not enabled.") - return + return nil } - body, ok := getRequest(baseURL.String(), cid, token) + body, status, ok := getRequest(baseURL.String(), cid, token) if !ok { - return + return fmt.Errorf("got status %d: %s", status, string(body)) } var s stream err = json.Unmarshal(body, &s) if err != nil { - log.Error().Err(err) - return + log.Error().Err(err).Msgf("error reading twitch data") + return err } games := s.Data gameID, title := "", "" if len(games) > 0 { gameID = games[0].GameID + if gameID == "" { + gameID = "unknown" + } title = games[0].Title } @@ -298,4 +335,47 @@ func (p *TwitchPlugin) checkTwitch(c bot.Connector, channel string, twitcher *Tw } twitcher.gameID = gameID } + return nil +} + +func (p *TwitchPlugin) validateCredentials() error { + cid := p.config.Get("twitch.clientid", "") + token := p.config.Get("twitch.token", "") + if token == "" { + return p.reAuthenticate() + } + _, status, ok := getRequest("https://id.twitch.tv/oauth2/validate", cid, token) + if !ok || status == http.StatusUnauthorized { + return p.reAuthenticate() + } + log.Debug().Msgf("checked credentials and they were valid") + return nil +} + +func (p *TwitchPlugin) reAuthenticate() error { + cid := p.config.Get("twitch.clientid", "") + secret := p.config.Get("twitch.secret", "") + if cid == "" || secret == "" { + return fmt.Errorf("could not request a new token without config values set") + } + resp, err := http.PostForm("https://id.twitch.tv/oauth2/token", url.Values{ + "client_id": {cid}, + "client_secret": {secret}, + "grant_type": {"client_credentials"}, + }) + if err != nil { + return err + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + credentials := struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + }{} + err = json.Unmarshal(body, &credentials) + log.Debug().Int("expires", credentials.ExpiresIn).Msgf("setting new twitch token") + return p.config.RegisterSecret("twitch.token", credentials.AccessToken) }