mirror of https://github.com/velour/catbase.git
twitch: add reauthentication
- refactored secrets to be in config - added missing format string to bot
This commit is contained in:
parent
45103cec62
commit
7c0a777737
|
@ -265,7 +265,7 @@ func (b *bot) CheckAdmin(ID string) bool {
|
||||||
log.Info().Interface("admins", admins).Msgf("Checking admin for %s", ID)
|
log.Info().Interface("admins", admins).Msgf("Checking admin for %s", ID)
|
||||||
for _, u := range admins {
|
for _, u := range admins {
|
||||||
if ID == u {
|
if ID == u {
|
||||||
log.Info().Msg("%s admin check: passed")
|
log.Info().Msgf("%s admin check: passed", u)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -201,6 +201,32 @@ func (c *Config) SecretKeys() []string {
|
||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) setSecret(key, value string) error {
|
||||||
|
q := `insert into secrets (key,value) values (?, ?)
|
||||||
|
on conflict(key) do update set value=?;`
|
||||||
|
_, err := c.Exec(q, key, value, value)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal().Err(err).Msgf("secret")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.RefreshSecrets()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterSecret creates a new secret
|
||||||
|
func (c *Config) RegisterSecret(key, value string) error {
|
||||||
|
return c.setSecret(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveSecret deregisters a secret
|
||||||
|
func (c *Config) RemoveSecret(key string) error {
|
||||||
|
q := `delete from secrets where key=?`
|
||||||
|
_, err := c.Exec(q, key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.RefreshSecrets()
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Config) SetMap(key string, values map[string]string) error {
|
func (c *Config) SetMap(key string, values map[string]string) error {
|
||||||
b, err := json.Marshal(values)
|
b, err := json.Marshal(values)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -256,7 +282,7 @@ func ReadConfig(dbpath string) *Config {
|
||||||
value string,
|
value string,
|
||||||
primary key (key)
|
primary key (key)
|
||||||
);`); err != nil {
|
);`); err != nil {
|
||||||
log.Fatal().Err(err).Msgf("failed to initialize config")
|
log.Fatal().Err(err).Msgf("failed to initialize secrets")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.RefreshSecrets(); err != nil {
|
if err := c.RefreshSecrets(); err != nil {
|
||||||
|
|
|
@ -46,33 +46,6 @@ func (p *SecretsPlugin) registerWeb() {
|
||||||
p.b.RegisterWebName(r, "/secrets", "Secrets")
|
p.b.RegisterWebName(r, "/secrets", "Secrets")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *SecretsPlugin) registerSecret(key, value string) error {
|
|
||||||
q := `insert into secrets (key, value) values (?, ?)`
|
|
||||||
_, err := p.db.Exec(q, key, value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return p.c.RefreshSecrets()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SecretsPlugin) removeSecret(key string) error {
|
|
||||||
q := `delete from secrets where key=?`
|
|
||||||
_, err := p.db.Exec(q, key)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return p.c.RefreshSecrets()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *SecretsPlugin) updateSecret(key, value string) error {
|
|
||||||
q := `update secrets set value=? where key=?`
|
|
||||||
_, err := p.db.Exec(q, value, key)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return p.c.RefreshSecrets()
|
|
||||||
}
|
|
||||||
|
|
||||||
func mkCheckError(w http.ResponseWriter) func(error) bool {
|
func mkCheckError(w http.ResponseWriter) func(error) bool {
|
||||||
return func(err error) bool {
|
return func(err error) bool {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -130,7 +103,7 @@ func (p *SecretsPlugin) handleRegister(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debug().Msgf("Secret: %s", secret)
|
log.Debug().Msgf("Secret: %s", secret)
|
||||||
err = p.registerSecret(secret.Key, secret.Value)
|
err = p.c.RegisterSecret(secret.Key, secret.Value)
|
||||||
if checkError(err) {
|
if checkError(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -148,7 +121,7 @@ func (p *SecretsPlugin) handleRemove(w http.ResponseWriter, r *http.Request) {
|
||||||
if checkError(err) {
|
if checkError(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = p.removeSecret(secret.Key)
|
err = p.c.RemoveSecret(secret.Key)
|
||||||
if checkError(err) {
|
if checkError(err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -12,6 +11,8 @@ import (
|
||||||
"text/template"
|
"text/template"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/velour/catbase/bot"
|
"github.com/velour/catbase/bot"
|
||||||
"github.com/velour/catbase/bot/msg"
|
"github.com/velour/catbase/bot/msg"
|
||||||
|
@ -75,9 +76,11 @@ func New(b bot.Bot) *TwitchPlugin {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
go p.twitchLoop(b.DefaultConnector(), ch)
|
go p.twitchChannelLoop(b.DefaultConnector(), ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go p.twitchAuthLoop(b.DefaultConnector())
|
||||||
|
|
||||||
b.Register(p, bot.Message, p.message)
|
b.Register(p, bot.Message, p.message)
|
||||||
b.Register(p, bot.Help, p.help)
|
b.Register(p, bot.Help, p.help)
|
||||||
p.registerWeb()
|
p.registerWeb()
|
||||||
|
@ -128,7 +131,10 @@ func (p *TwitchPlugin) message(c bot.Connector, kind bot.Kind, message msg.Messa
|
||||||
if users := p.config.GetArray("Twitch."+channel+".Users", []string{}); len(users) > 0 {
|
if users := p.config.GetArray("Twitch."+channel+".Users", []string{}); len(users) > 0 {
|
||||||
for _, twitcherName := range users {
|
for _, twitcherName := range users {
|
||||||
if _, ok := p.twitchList[twitcherName]; ok {
|
if _, ok := p.twitchList[twitcherName]; ok {
|
||||||
p.checkTwitch(c, channel, p.twitchList[twitcherName], true)
|
err := p.checkTwitch(c, channel, p.twitchList[twitcherName], true)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msgf("error in checking twitch")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -153,25 +159,52 @@ func (p *TwitchPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message,
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TwitchPlugin) twitchLoop(c bot.Connector, channel string) {
|
func (p *TwitchPlugin) twitchAuthLoop(c bot.Connector) {
|
||||||
|
frequency := p.config.GetInt("Twitch.AuthFreq", 60*60)
|
||||||
|
cid := p.config.Get("twitch.clientid", "")
|
||||||
|
secret := p.config.Get("twitch.secret", "")
|
||||||
|
if cid == "" || secret == "" {
|
||||||
|
log.Info().Msgf("Disabling twitch autoauth.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Msgf("Checking auth every %d seconds", frequency)
|
||||||
|
|
||||||
|
if err := p.validateCredentials(); err != nil {
|
||||||
|
log.Error().Err(err).Msgf("error checking twitch validity")
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-time.After(time.Duration(frequency) * time.Second):
|
||||||
|
if err := p.validateCredentials(); err != nil {
|
||||||
|
log.Error().Err(err).Msgf("error checking twitch validity")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TwitchPlugin) twitchChannelLoop(c bot.Connector, channel string) {
|
||||||
frequency := p.config.GetInt("Twitch.Freq", 60)
|
frequency := p.config.GetInt("Twitch.Freq", 60)
|
||||||
if p.config.Get("twitch.clientid", "") == "" || p.config.Get("twitch.token", "") == "" {
|
if p.config.Get("twitch.clientid", "") == "" || p.config.Get("twitch.secret", "") == "" {
|
||||||
log.Info().Msgf("Disabling twitch autochecking.")
|
log.Info().Msgf("Disabling twitch autochecking.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Msgf("Checking every %d seconds", frequency)
|
log.Info().Msgf("Checking channels every %d seconds", frequency)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Duration(frequency) * time.Second)
|
time.Sleep(time.Duration(frequency) * time.Second)
|
||||||
|
|
||||||
for _, twitcherName := range p.config.GetArray("Twitch."+channel+".Users", []string{}) {
|
for _, twitcherName := range p.config.GetArray("Twitch."+channel+".Users", []string{}) {
|
||||||
p.checkTwitch(c, channel, p.twitchList[twitcherName], false)
|
if err := p.checkTwitch(c, channel, p.twitchList[twitcherName], false); err != nil {
|
||||||
|
log.Error().Err(err).Msgf("error in twitch loop")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRequest(url, clientID, token string) ([]byte, bool) {
|
func getRequest(url, clientID, token string) ([]byte, int, bool) {
|
||||||
bearer := fmt.Sprintf("Bearer %s", token)
|
bearer := fmt.Sprintf("Bearer %s", token)
|
||||||
var body []byte
|
var body []byte
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
|
@ -193,18 +226,19 @@ func getRequest(url, clientID, token string) ([]byte, bool) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
goto errCase
|
goto errCase
|
||||||
}
|
}
|
||||||
return body, true
|
return body, resp.StatusCode, true
|
||||||
|
|
||||||
errCase:
|
errCase:
|
||||||
log.Error().Err(err)
|
log.Error().Err(err)
|
||||||
return []byte{}, false
|
return []byte{}, resp.StatusCode, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TwitchPlugin) checkTwitch(c bot.Connector, channel string, twitcher *Twitcher, alwaysPrintStatus bool) {
|
func (p *TwitchPlugin) checkTwitch(c bot.Connector, channel string, twitcher *Twitcher, alwaysPrintStatus bool) error {
|
||||||
baseURL, err := url.Parse("https://api.twitch.tv/helix/streams")
|
baseURL, err := url.Parse("https://api.twitch.tv/helix/streams")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Msg("Error parsing twitch stream URL")
|
err := fmt.Errorf("error parsing twitch stream URL")
|
||||||
return
|
log.Error().Msg(err.Error())
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
query := baseURL.Query()
|
query := baseURL.Query()
|
||||||
|
@ -216,25 +250,28 @@ func (p *TwitchPlugin) checkTwitch(c bot.Connector, channel string, twitcher *Tw
|
||||||
token := p.config.Get("twitch.token", "")
|
token := p.config.Get("twitch.token", "")
|
||||||
if cid == token && cid == "" {
|
if cid == token && cid == "" {
|
||||||
log.Info().Msgf("Twitch plugin not enabled.")
|
log.Info().Msgf("Twitch plugin not enabled.")
|
||||||
return
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
body, ok := getRequest(baseURL.String(), cid, token)
|
body, status, ok := getRequest(baseURL.String(), cid, token)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return fmt.Errorf("got status %d: %s", status, string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
var s stream
|
var s stream
|
||||||
err = json.Unmarshal(body, &s)
|
err = json.Unmarshal(body, &s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err)
|
log.Error().Err(err).Msgf("error reading twitch data")
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
games := s.Data
|
games := s.Data
|
||||||
gameID, title := "", ""
|
gameID, title := "", ""
|
||||||
if len(games) > 0 {
|
if len(games) > 0 {
|
||||||
gameID = games[0].GameID
|
gameID = games[0].GameID
|
||||||
|
if gameID == "" {
|
||||||
|
gameID = "unknown"
|
||||||
|
}
|
||||||
title = games[0].Title
|
title = games[0].Title
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -298,4 +335,47 @@ func (p *TwitchPlugin) checkTwitch(c bot.Connector, channel string, twitcher *Tw
|
||||||
}
|
}
|
||||||
twitcher.gameID = gameID
|
twitcher.gameID = gameID
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TwitchPlugin) validateCredentials() error {
|
||||||
|
cid := p.config.Get("twitch.clientid", "")
|
||||||
|
token := p.config.Get("twitch.token", "")
|
||||||
|
if token == "" {
|
||||||
|
return p.reAuthenticate()
|
||||||
|
}
|
||||||
|
_, status, ok := getRequest("https://id.twitch.tv/oauth2/validate", cid, token)
|
||||||
|
if !ok || status == http.StatusUnauthorized {
|
||||||
|
return p.reAuthenticate()
|
||||||
|
}
|
||||||
|
log.Debug().Msgf("checked credentials and they were valid")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TwitchPlugin) reAuthenticate() error {
|
||||||
|
cid := p.config.Get("twitch.clientid", "")
|
||||||
|
secret := p.config.Get("twitch.secret", "")
|
||||||
|
if cid == "" || secret == "" {
|
||||||
|
return fmt.Errorf("could not request a new token without config values set")
|
||||||
|
}
|
||||||
|
resp, err := http.PostForm("https://id.twitch.tv/oauth2/token", url.Values{
|
||||||
|
"client_id": {cid},
|
||||||
|
"client_secret": {secret},
|
||||||
|
"grant_type": {"client_credentials"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
credentials := struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
}{}
|
||||||
|
err = json.Unmarshal(body, &credentials)
|
||||||
|
log.Debug().Int("expires", credentials.ExpiresIn).Msgf("setting new twitch token")
|
||||||
|
return p.config.RegisterSecret("twitch.token", credentials.AccessToken)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue