diff --git a/bot/bot.go b/bot/bot.go index b59c29c..a17e2ae 100644 --- a/bot/bot.go +++ b/bot/bot.go @@ -129,8 +129,11 @@ func (b *bot) ListenAndServe() { log.Fatal().Err(http.ListenAndServe(addr, b.router)).Msg("bot killed") } -func (b *bot) RegisterWeb(r http.Handler, root, name string) { - log.Debug().Msgf("registering %s at %s", name, root) +func (b *bot) RegisterWeb(r http.Handler, root string) { + b.router.Mount(root, r) +} + +func (b *bot) RegisterWebName(r http.Handler, root, name string) { b.httpEndPoints = append(b.httpEndPoints, EndPoint{name, root}) b.router.Mount(root, r) } @@ -391,6 +394,9 @@ func PluginName(p Plugin) string { } func (b *bot) CheckPassword(secret, password string) bool { + if password == "" { + return false + } if b.password == password { return true } diff --git a/bot/interfaces.go b/bot/interfaces.go index 531bb5c..64381e2 100644 --- a/bot/interfaces.go +++ b/bot/interfaces.go @@ -135,7 +135,10 @@ type Bot interface { RegisterFilter(string, func(string) string) // RegisterWeb records a web endpoint for the UI - RegisterWeb(http.Handler, string, string) + RegisterWebName(http.Handler, string, string) + + // RegisterWeb records a web endpoint for the API + RegisterWeb(http.Handler, string) // Start the HTTP service ListenAndServe() @@ -202,6 +205,9 @@ type Connector interface { // GetChannelName returns the channel ID for a human-friendly name (if possible) GetChannelID(id string) string + + // Get any web handlers the connector exposes + GetRouter() (http.Handler, string) } // Plugin interface used for compatibility with the Plugin interface diff --git a/bot/mock.go b/bot/mock.go index 8ee861d..9279fb2 100644 --- a/bot/mock.go +++ b/bot/mock.go @@ -57,7 +57,8 @@ func (mb *MockBot) Register(p Plugin, kind Kind, cb Callback) func (mb *MockBot) RegisterTable(p Plugin, hs HandlerTable) {} func (mb *MockBot) RegisterRegex(p Plugin, kind Kind, r *regexp.Regexp, h ResponseHandler) {} func (mb *MockBot) RegisterRegexCmd(p Plugin, kind Kind, r *regexp.Regexp, h ResponseHandler) {} -func (mb *MockBot) RegisterWeb(_, _ string) {} +func (mb *MockBot) RegisterWebName(_ http.Handler, _, _ string) {} +func (mb *MockBot) RegisterWeb(_ http.Handler, _ string) {} func (mb *MockBot) GetWebNavigation() []EndPoint { return nil } func (mb *MockBot) Receive(c Connector, kind Kind, msg msg.Message, args ...interface{}) bool { return false @@ -124,3 +125,4 @@ func (mb *MockBot) GetWhitelist() []string { return []string func (mb *MockBot) OnBlacklist(ch, p string) bool { return false } func (mb *MockBot) URLFormat(title, url string) string { return title + url } func (mb *MockBot) CheckPassword(secret, password string) bool { return true } +func (mb *MockBot) ListenAndServe() {} diff --git a/connectors/discord/discord.go b/connectors/discord/discord.go index 2694f16..775fff1 100644 --- a/connectors/discord/discord.go +++ b/connectors/discord/discord.go @@ -3,6 +3,7 @@ package discord import ( "errors" "fmt" + "net/http" "strings" "github.com/velour/catbase/bot/msg" @@ -34,6 +35,9 @@ func New(config *config.Config) *Discord { } return d } +func (d *Discord) GetRouter() (http.Handler, string) { + return nil, "" +} func (d *Discord) RegisterEvent(callback bot.Callback) { d.event = callback diff --git a/connectors/irc/irc.go b/connectors/irc/irc.go index 800146a..c329a27 100644 --- a/connectors/irc/irc.go +++ b/connectors/irc/irc.go @@ -5,6 +5,7 @@ package irc import ( "fmt" "io" + "net/http" "os" "strings" "time" @@ -52,6 +53,10 @@ func New(c *config.Config) *Irc { return &i } +func (i *Irc) GetRouter() (http.Handler, string) { + return nil, "" +} + func (i *Irc) RegisterEvent(f bot.Callback) { i.event = f } diff --git a/connectors/slack/slack.go b/connectors/slack/slack.go index 1e30c8d..377b91e 100644 --- a/connectors/slack/slack.go +++ b/connectors/slack/slack.go @@ -174,6 +174,10 @@ func New(c *config.Config) *Slack { } } +func (s *Slack) GetRouter() (http.Handler, string) { + return nil, "" +} + func (s *Slack) Send(kind bot.Kind, args ...interface{}) (string, error) { switch kind { case bot.Message: diff --git a/connectors/slackapp/slackApp.go b/connectors/slackapp/slackApp.go index 09d4b40..8a12ab2 100644 --- a/connectors/slackapp/slackApp.go +++ b/connectors/slackapp/slackApp.go @@ -17,6 +17,7 @@ import ( "text/template" "time" + "github.com/go-chi/chi/v5" zerowidth "github.com/trubitsyn/go-zero-width" "github.com/rs/zerolog/log" @@ -44,6 +45,7 @@ const defaultLogFormat = "[{{fixDate .Time \"2006-01-02 15:04:05\"}}] {{if .Topi type SlackApp struct { config *config.Config api *slack.Client + router *chi.Mux botToken string userToken string @@ -84,6 +86,7 @@ func New(c *config.Config) *SlackApp { return &SlackApp{ api: api, + router: chi.NewRouter(), config: c, botToken: token, userToken: c.Get("slack.usertoken", "NONE"), @@ -103,10 +106,14 @@ func (s *SlackApp) RegisterEvent(f bot.Callback) { s.event = f } +func (s *SlackApp) GetRouter() (http.Handler, string) { + return s.router, "/evt" +} + func (s *SlackApp) Serve() error { s.populateEmojiList() - http.HandleFunc("/evt", func(w http.ResponseWriter, r *http.Request) { + s.router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { buf := new(bytes.Buffer) buf.ReadFrom(r.Body) body := buf.String() diff --git a/main.go b/main.go index c508924..5758cbe 100644 --- a/main.go +++ b/main.go @@ -120,6 +120,10 @@ func main() { b := bot.New(c, client) + if r, path := client.GetRouter(); r != nil { + b.RegisterWeb(r, path) + } + b.AddPlugin(admin.New(b)) b.AddPlugin(secrets.New(b)) b.AddPlugin(giphy.New(b)) diff --git a/plugins/admin/web.go b/plugins/admin/web.go index de08bf6..00d8c50 100644 --- a/plugins/admin/web.go +++ b/plugins/admin/web.go @@ -18,12 +18,12 @@ func (p *AdminPlugin) registerWeb() { r := chi.NewRouter() r.HandleFunc("/api", p.handleVarsAPI) r.HandleFunc("/", p.handleVars) - p.bot.RegisterWeb(r, "/vars", "Variables") + p.bot.RegisterWebName(r, "/vars", "Variables") r = chi.NewRouter() r.HandleFunc("/verify", p.handleAppPassCheck) r.HandleFunc("/api", p.handleAppPassAPI) r.HandleFunc("/", p.handleAppPass) - p.bot.RegisterWeb(r, "/apppass", "App Pass") + p.bot.RegisterWebName(r, "/apppass", "App Pass") } func (p *AdminPlugin) handleAppPass(w http.ResponseWriter, r *http.Request) { diff --git a/plugins/beers/beers.go b/plugins/beers/beers.go index c600ff2..f3facb3 100644 --- a/plugins/beers/beers.go +++ b/plugins/beers/beers.go @@ -19,6 +19,7 @@ import ( "strings" "time" + "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/nfnt/resize" @@ -589,7 +590,9 @@ func (p *BeersPlugin) untappdLoop(c bot.Connector, channel string) { } func (p *BeersPlugin) registerWeb() { - http.HandleFunc("/beers/img/", p.img) + r := chi.NewRouter() + r.HandleFunc("/img", p.img) + p.b.RegisterWeb(r, "/beers") } func (p *BeersPlugin) img(w http.ResponseWriter, r *http.Request) { diff --git a/plugins/cli/cli.go b/plugins/cli/cli.go index 16bcc68..7d98216 100644 --- a/plugins/cli/cli.go +++ b/plugins/cli/cli.go @@ -36,7 +36,11 @@ func (p *CliPlugin) registerWeb() { r := chi.NewRouter() r.HandleFunc("/api", p.handleWebAPI) r.HandleFunc("/", p.handleWeb) - p.bot.RegisterWeb(r, "/cli", "CLI") + p.bot.RegisterWebName(r, "/cli", "CLI") +} + +func (p *CliPlugin) GetRouter() (http.Handler, string) { + return nil, "" } func (p *CliPlugin) handleWebAPI(w http.ResponseWriter, r *http.Request) { diff --git a/plugins/counter/api.go b/plugins/counter/api.go index 66cab87..d5ef074 100644 --- a/plugins/counter/api.go +++ b/plugins/counter/api.go @@ -10,7 +10,6 @@ import ( "github.com/rs/zerolog/log" "github.com/velour/catbase/bot" "github.com/velour/catbase/bot/msg" - "github.com/velour/catbase/bot/user" ) func (p *CounterPlugin) registerWeb() { @@ -19,19 +18,53 @@ func (p *CounterPlugin) registerWeb() { r.HandleFunc("/api/users/{user}/items/{item}/decrement", p.mkIncrementAPI(-1)) r.HandleFunc("/api", p.handleCounterAPI) r.HandleFunc("/", p.handleCounter) - p.b.RegisterWeb(r, "/counter", "Counter") + p.b.RegisterWebName(r, "/counter", "Counter") } func (p *CounterPlugin) mkIncrementAPI(delta int) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - vars := map[string]string{} - userName := vars["user"] - itemName := vars["item"] - item, err := GetUserItem(p.db, userName, "", itemName) - if err != nil { + userName := chi.URLParam(r, "user") + itemName := chi.URLParam(r, "item") + + secret, pass, ok := r.BasicAuth() + if !ok || !p.b.CheckPassword(secret, pass) { + err := fmt.Errorf("unauthorized access") + log.Error(). + Err(err). + Msg("error authenticating user") + w.WriteHeader(401) + j, _ := json.Marshal(struct { + Status bool + Error string + }{false, err.Error()}) + fmt.Fprint(w, string(j)) + return + } + + // Try to find an ID if possible + u, err := p.b.DefaultConnector().Profile(userName) + if err != nil { + log.Error().Err(err).Msg("error finding user") + w.WriteHeader(400) + j, _ := json.Marshal(struct { + Status bool + Error error + }{false, err}) + fmt.Fprint(w, string(j)) + return + } + + item, err := GetUserItem(p.db, userName, u.ID, itemName) + if err != nil { + log.Error().Err(err).Msg("error finding item") + w.WriteHeader(400) + j, _ := json.Marshal(struct { + Status bool + Error error + }{false, err}) + fmt.Fprint(w, string(j)) return } - u := user.New(userName) req := &bot.Request{ Conn: p.b.DefaultConnector(), Kind: bot.Message, @@ -45,8 +78,8 @@ func (p *CounterPlugin) mkIncrementAPI(delta int) func(w http.ResponseWriter, r Args: nil, } item.UpdateDelta(req, delta) - msg := fmt.Sprintf("%s changed their %s counter by %d via the amazing %s API", - userName, itemName, delta, p.cfg.Get("nick", "catbase")) + msg := fmt.Sprintf("%s changed their %s counter by %d for a total of %d via the amazing %s API", + userName, itemName, delta, item.Count, p.cfg.Get("nick", "catbase")) for _, ch := range p.cfg.GetArray("channels", []string{}) { p.b.Send(p.b.DefaultConnector(), bot.Message, ch, msg) } diff --git a/plugins/counter/counter.go b/plugins/counter/counter.go index dcc87a8..eedc918 100644 --- a/plugins/counter/counter.go +++ b/plugins/counter/counter.go @@ -321,8 +321,9 @@ func New(b bot.Bot) *CounterPlugin { } cp := &CounterPlugin{ - b: b, - db: b.DB(), + b: b, + db: b.DB(), + cfg: b.Config(), } b.RegisterRegex(cp, bot.Startup, regexp.MustCompile(`.*`), cp.migrate) diff --git a/plugins/fact/factoid.go b/plugins/fact/factoid.go index 0794f1d..9b493f7 100644 --- a/plugins/fact/factoid.go +++ b/plugins/fact/factoid.go @@ -806,7 +806,7 @@ func (p *FactoidPlugin) registerWeb() { r.HandleFunc("/api", p.serveAPI) r.HandleFunc("/req", p.serveQuery) r.HandleFunc("/", p.serveQuery) - p.Bot.RegisterWeb(r, "/factoid", "Factoid") + p.Bot.RegisterWebName(r, "/factoid", "Factoid") } func linkify(text string) template.HTML { diff --git a/plugins/git/git.go b/plugins/git/git.go index a0b00bc..0b8123b 100644 --- a/plugins/git/git.go +++ b/plugins/git/git.go @@ -2,9 +2,9 @@ package git import ( "fmt" - "net/http" "regexp" + "github.com/go-chi/chi/v5" "github.com/rs/zerolog/log" "gopkg.in/go-playground/webhooks.v5/github" "gopkg.in/go-playground/webhooks.v5/gitlab" @@ -77,7 +77,9 @@ func (p *GitPlugin) register() { } func (p *GitPlugin) registerWeb() { - http.HandleFunc("/git/gitea/event", p.giteaEvent) - http.HandleFunc("/git/github/event", p.githubEvent) - http.HandleFunc("/git/gitlab/event", p.gitlabEvent) + r := chi.NewRouter() + r.HandleFunc("/gitea/event", p.giteaEvent) + r.HandleFunc("/github/event", p.githubEvent) + r.HandleFunc("/gitlab/event", p.gitlabEvent) + p.b.RegisterWeb(r, "/git") } diff --git a/plugins/meme/meme.go b/plugins/meme/meme.go index 223c702..232507b 100644 --- a/plugins/meme/meme.go +++ b/plugins/meme/meme.go @@ -361,6 +361,9 @@ func (p *MemePlugin) checkMeme(imgURL string) (int, int, error) { } img, err := DownloadTemplate(u) + if err != nil { + return 0, 0, err + } return img.Bounds().Dx(), img.Bounds().Dy(), err } diff --git a/plugins/meme/webHandlers.go b/plugins/meme/webHandlers.go index 8ab0cdd..b46ea6d 100644 --- a/plugins/meme/webHandlers.go +++ b/plugins/meme/webHandlers.go @@ -21,7 +21,7 @@ func (p *MemePlugin) registerWeb(c bot.Connector) { r.HandleFunc("/add", p.addMeme) r.HandleFunc("/rm", p.rmMeme) r.HandleFunc("/", p.webRoot) - p.bot.RegisterWeb(r, "/meme", "Memes") + p.bot.RegisterWebName(r, "/meme", "Memes") } type webResp struct { diff --git a/plugins/secrets/secrets.go b/plugins/secrets/secrets.go index e5dbc05..0fa9a4c 100644 --- a/plugins/secrets/secrets.go +++ b/plugins/secrets/secrets.go @@ -39,7 +39,7 @@ func (p *SecretsPlugin) registerWeb() { w.Write(j) }) r.HandleFunc("/", p.handleIndex) - p.b.RegisterWeb(r, "/secrets", "Secrets") + p.b.RegisterWebName(r, "/secrets", "Secrets") } func (p *SecretsPlugin) registerSecret(key, value string) error { diff --git a/plugins/sms/sms.go b/plugins/sms/sms.go index fbc00b8..9f81e7b 100644 --- a/plugins/sms/sms.go +++ b/plugins/sms/sms.go @@ -6,6 +6,7 @@ import ( "regexp" "strings" + "github.com/go-chi/chi/v5" twilio "github.com/kevinburke/twilio-go" "github.com/rs/zerolog/log" @@ -118,7 +119,9 @@ func (p *SMSPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, ar } func (p *SMSPlugin) registerWeb() { - http.HandleFunc("/sms/new", p.receive) + r := chi.NewRouter() + r.HandleFunc("/new", p.receive) + p.b.RegisterWeb(r, "/sms") } func (p *SMSPlugin) receive(w http.ResponseWriter, r *http.Request) { diff --git a/plugins/talker/talker.go b/plugins/talker/talker.go index 4f35a06..45d95e7 100644 --- a/plugins/talker/talker.go +++ b/plugins/talker/talker.go @@ -10,6 +10,7 @@ import ( "os/exec" "strings" + "github.com/go-chi/chi/v5" "github.com/rs/zerolog/log" "github.com/velour/catbase/bot" @@ -171,7 +172,8 @@ func (p *TalkerPlugin) allCows() []string { } func (p *TalkerPlugin) registerWeb(c bot.Connector) { - http.HandleFunc("/slash/cowsay", func(w http.ResponseWriter, r *http.Request) { + r := chi.NewRouter() + r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { r.ParseForm() log.Debug().Msgf("Cowsay:\n%+v", r.PostForm.Get("text")) channel := r.PostForm.Get("channel_id") @@ -184,4 +186,5 @@ func (p *TalkerPlugin) registerWeb(c bot.Connector) { p.bot.Send(c, bot.Message, channel, msg) w.WriteHeader(200) }) + p.bot.RegisterWeb(r, "/cowsay") } diff --git a/plugins/twitch/twitch.go b/plugins/twitch/twitch.go index 155fa9f..e82c6f7 100644 --- a/plugins/twitch/twitch.go +++ b/plugins/twitch/twitch.go @@ -11,6 +11,7 @@ import ( "text/template" "time" + "github.com/go-chi/chi/v5" "github.com/rs/zerolog/log" "github.com/velour/catbase/bot" "github.com/velour/catbase/bot/msg" @@ -85,7 +86,9 @@ func New(b bot.Bot) *TwitchPlugin { } func (p *TwitchPlugin) registerWeb() { - http.HandleFunc("/isstreaming/", p.serveStreaming) + r := chi.NewRouter() + r.HandleFunc("/", p.serveStreaming) + p.bot.RegisterWeb(r, "/isstreaming") } func (p *TwitchPlugin) serveStreaming(w http.ResponseWriter, r *http.Request) {