web: migrate endpoint registration to chi

This commit is contained in:
Chris Sexton 2021-07-21 14:52:45 -04:00 committed by Chris Sexton
parent 31fe84b0ba
commit c47a4f7c6f
17 changed files with 260 additions and 178 deletions

View File

@ -11,6 +11,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/velour/catbase/bot/msg" "github.com/velour/catbase/bot/msg"
@ -59,6 +61,8 @@ type bot struct {
passwordCreated time.Time passwordCreated time.Time
quiet bool quiet bool
router *chi.Mux
} }
type EndPoint struct { type EndPoint struct {
@ -98,6 +102,7 @@ func New(config *config.Config, connector Connector) Bot {
httpEndPoints: make([]EndPoint, 0), httpEndPoints: make([]EndPoint, 0),
filters: make(map[string]func(string) string), filters: make(map[string]func(string) string),
callbacks: make(CallbackMap), callbacks: make(CallbackMap),
router: chi.NewRouter(),
} }
bot.migrateDB() bot.migrateDB()
@ -105,14 +110,31 @@ func New(config *config.Config, connector Connector) Bot {
bot.RefreshPluginBlacklist() bot.RefreshPluginBlacklist()
bot.RefreshPluginWhitelist() bot.RefreshPluginWhitelist()
http.HandleFunc("/", bot.serveRoot) log.Debug().Msgf("created web router")
http.HandleFunc("/nav", bot.serveNav)
bot.router.Use(middleware.Logger)
bot.router.Use(middleware.StripSlashes)
bot.router.HandleFunc("/", bot.serveRoot)
bot.router.HandleFunc("/nav", bot.serveNav)
connector.RegisterEvent(bot.Receive) connector.RegisterEvent(bot.Receive)
return bot return bot
} }
func (b *bot) ListenAndServe() {
addr := b.config.Get("HttpAddr", "127.0.0.1:1337")
log.Debug().Msgf("starting web service at %s", addr)
log.Fatal().Err(http.ListenAndServe(addr, b.router)).Msg("bot killed")
}
func (b *bot) RegisterWeb(r http.Handler, root, name string) {
log.Debug().Msgf("registering %s at %s", name, root)
b.httpEndPoints = append(b.httpEndPoints, EndPoint{name, root})
b.router.Mount(root, r)
}
// DefaultConnector is the main connector used for the bot // DefaultConnector is the main connector used for the bot
// If more than one connector is on, some users may not see all messages if this is used. // If more than one connector is on, some users may not see all messages if this is used.
// Usage should be limited to out-of-band communications such as timed messages. // Usage should be limited to out-of-band communications such as timed messages.
@ -143,19 +165,19 @@ func (b *bot) migrateDB() {
name string, name string,
value string value string
);`); err != nil { );`); err != nil {
log.Fatal().Err(err).Msgf("Initial DB migration create variables table") log.Fatal().Err(err).Msgf("Initial db migration create variables table")
} }
if _, err := b.DB().Exec(`create table if not exists pluginBlacklist ( if _, err := b.DB().Exec(`create table if not exists pluginBlacklist (
channel string, channel string,
name string, name string,
primary key (channel, name) primary key (channel, name)
);`); err != nil { );`); err != nil {
log.Fatal().Err(err).Msgf("Initial DB migration create blacklist table") log.Fatal().Err(err).Msgf("Initial db migration create blacklist table")
} }
if _, err := b.DB().Exec(`create table if not exists pluginWhitelist ( if _, err := b.DB().Exec(`create table if not exists pluginWhitelist (
name string primary key name string primary key
);`); err != nil { );`); err != nil {
log.Fatal().Err(err).Msgf("Initial DB migration create whitelist table") log.Fatal().Err(err).Msgf("Initial db migration create whitelist table")
} }
} }
@ -280,10 +302,6 @@ func (b *bot) Register(p Plugin, kind Kind, cb Callback) {
b.RegisterRegex(p, kind, r, resp) b.RegisterRegex(p, kind, r, resp)
} }
func (b *bot) RegisterWeb(root, name string) {
b.httpEndPoints = append(b.httpEndPoints, EndPoint{name, root})
}
// GetPassword returns a random password generated for the bot // GetPassword returns a random password generated for the bot
// Passwords expire in 24h and are used for the web interface // Passwords expire in 24h and are used for the web interface
func (b *bot) GetPassword() string { func (b *bot) GetPassword() string {

View File

@ -3,6 +3,7 @@
package bot package bot
import ( import (
"net/http"
"regexp" "regexp"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@ -134,7 +135,10 @@ type Bot interface {
RegisterFilter(string, func(string) string) RegisterFilter(string, func(string) string)
// RegisterWeb records a web endpoint for the UI // RegisterWeb records a web endpoint for the UI
RegisterWeb(string, string) RegisterWeb(http.Handler, string, string)
// Start the HTTP service
ListenAndServe()
// DefaultConnector returns the base connector, which may not be the only connector // DefaultConnector returns the base connector, which may not be the only connector
DefaultConnector() Connector DefaultConnector() Connector

View File

@ -36,7 +36,7 @@ type Secret struct {
// GetFloat64 returns the config value for a string key // GetFloat64 returns the config value for a string key
// It will first look in the env vars for the key // It will first look in the env vars for the key
// It will check the DB for the key if an env DNE // It will check the db for the key if an env DNE
// Finally, it will return a zero value if the key does not exist // Finally, it will return a zero value if the key does not exist
// It will attempt to convert the value to a float64 if it exists // It will attempt to convert the value to a float64 if it exists
func (c *Config) GetFloat64(key string, fallback float64) float64 { func (c *Config) GetFloat64(key string, fallback float64) float64 {
@ -49,7 +49,7 @@ func (c *Config) GetFloat64(key string, fallback float64) float64 {
// GetInt64 returns the config value for a string key // GetInt64 returns the config value for a string key
// It will first look in the env vars for the key // It will first look in the env vars for the key
// It will check the DB for the key if an env DNE // It will check the db for the key if an env DNE
// Finally, it will return a zero value if the key does not exist // Finally, it will return a zero value if the key does not exist
// It will attempt to convert the value to an int if it exists // It will attempt to convert the value to an int if it exists
func (c *Config) GetInt64(key string, fallback int64) int64 { func (c *Config) GetInt64(key string, fallback int64) int64 {
@ -62,7 +62,7 @@ func (c *Config) GetInt64(key string, fallback int64) int64 {
// GetInt returns the config value for a string key // GetInt returns the config value for a string key
// It will first look in the env vars for the key // It will first look in the env vars for the key
// It will check the DB for the key if an env DNE // It will check the db for the key if an env DNE
// Finally, it will return a zero value if the key does not exist // Finally, it will return a zero value if the key does not exist
// It will attempt to convert the value to an int if it exists // It will attempt to convert the value to an int if it exists
func (c *Config) GetInt(key string, fallback int) int { func (c *Config) GetInt(key string, fallback int) int {
@ -86,7 +86,7 @@ func envkey(key string) string {
// GetString returns the config value for a string key // GetString returns the config value for a string key
// It will first look in the env vars for the key // It will first look in the env vars for the key
// It will check the DB for the key if an env DNE // It will check the db for the key if an env DNE
// Finally, it will return a zero value if the key does not exist // Finally, it will return a zero value if the key does not exist
// It will convert the value to a string if it exists // It will convert the value to a string if it exists
func (c *Config) GetString(key, fallback string) string { func (c *Config) GetString(key, fallback string) string {
@ -124,7 +124,7 @@ func (c *Config) GetMap(key string, fallback map[string]string) map[string]strin
// GetArray returns the string slice config value for a string key // GetArray returns the string slice config value for a string key
// It will first look in the env vars for the key with ;; separated values // It will first look in the env vars for the key with ;; separated values
// Look, I'm too lazy to do parsing to ensure that a comma is what the user meant // Look, I'm too lazy to do parsing to ensure that a comma is what the user meant
// It will check the DB for the key if an env DNE // It will check the db for the key if an env DNE
// Finally, it will return a zero value if the key does not exist // Finally, it will return a zero value if the key does not exist
// This will do no conversion. // This will do no conversion.
func (c *Config) GetArray(key string, fallback []string) []string { func (c *Config) GetArray(key string, fallback []string) []string {

View File

@ -24,7 +24,7 @@ type Discord struct {
} }
func New(config *config.Config) *Discord { func New(config *config.Config) *Discord {
client, err := discordgo.New("Bot " + config.Get("DISCORDBOTTOKEN", "")) client, err := discordgo.New("b " + config.Get("DISCORDBOTTOKEN", ""))
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Could not connect to Discord") log.Fatal().Err(err).Msg("Could not connect to Discord")
} }

View File

@ -696,7 +696,6 @@ func (s *SlackApp) Profile(identifier string) (user.User, error) {
} }
for _, u := range users { for _, u := range users {
log.Debug().Str("Name", u.Name).Str("ID", u.ID).Msgf("Looking for %s", identifier)
if u.Name == identifier || u.ID == identifier { if u.Name == identifier || u.ID == identifier {
return user.User{ return user.User{
ID: u.ID, ID: u.ID,

1
go.mod
View File

@ -22,6 +22,7 @@ require (
github.com/dustin/gojson v0.0.0-20160307161227-2e71ec9dd5ad // indirect github.com/dustin/gojson v0.0.0-20160307161227-2e71ec9dd5ad // indirect
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90 github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90
github.com/garyburd/go-oauth v0.0.0-20180319155456-bca2e7f09a17 // indirect github.com/garyburd/go-oauth v0.0.0-20180319155456-bca2e7f09a17 // indirect
github.com/go-chi/chi/v5 v5.0.3
github.com/go-stack/stack v1.8.0 // indirect github.com/go-stack/stack v1.8.0 // indirect
github.com/gobwas/glob v0.2.3 // indirect github.com/gobwas/glob v0.2.3 // indirect
github.com/gocolly/colly v1.2.0 github.com/gocolly/colly v1.2.0

2
go.sum
View File

@ -45,6 +45,8 @@ github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90 h1:WXb3TSNmHp2vHoCro
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/garyburd/go-oauth v0.0.0-20180319155456-bca2e7f09a17 h1:GOfMz6cRgTJ9jWV0qAezv642OhPnKEG7gtUjJSdStHE= github.com/garyburd/go-oauth v0.0.0-20180319155456-bca2e7f09a17 h1:GOfMz6cRgTJ9jWV0qAezv642OhPnKEG7gtUjJSdStHE=
github.com/garyburd/go-oauth v0.0.0-20180319155456-bca2e7f09a17/go.mod h1:HfkOCN6fkKKaPSAeNq/er3xObxTW4VLeY6UUK895gLQ= github.com/garyburd/go-oauth v0.0.0-20180319155456-bca2e7f09a17/go.mod h1:HfkOCN6fkKKaPSAeNq/er3xObxTW4VLeY6UUK895gLQ=
github.com/go-chi/chi/v5 v5.0.3 h1:khYQBdPivkYG1s1TAzDQG1f6eX4kD2TItYVZexL5rS4=
github.com/go-chi/chi/v5 v5.0.3/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=

View File

@ -6,7 +6,6 @@ import (
"flag" "flag"
"io" "io"
"math/rand" "math/rand"
"net/http"
"os" "os"
"time" "time"
@ -67,7 +66,7 @@ import (
var ( var (
key = flag.String("set", "", "Configuration key to set") key = flag.String("set", "", "Configuration key to set")
val = flag.String("val", "", "Configuration value to set") val = flag.String("val", "", "Configuration value to set")
initDB = flag.Bool("init", false, "Initialize the configuration DB") initDB = flag.Bool("init", false, "Initialize the configuration db")
prettyLog = flag.Bool("pretty", false, "Use pretty console logger") prettyLog = flag.Bool("pretty", false, "Use pretty console logger")
debug = flag.Bool("debug", false, "Turn on debug logging") debug = flag.Bool("debug", false, "Turn on debug logging")
) )
@ -169,7 +168,5 @@ func main() {
b.Receive(client, bot.Startup, msg.Message{}) b.Receive(client, bot.Startup, msg.Message{})
addr := c.Get("HttpAddr", "127.0.0.1:1337") b.ListenAndServe()
log.Debug().Msgf("starting web service at %s", addr)
log.Fatal().Err(http.ListenAndServe(addr, nil)).Msg("bot killed")
} }

View File

@ -9,18 +9,21 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
func (p *AdminPlugin) registerWeb() { func (p *AdminPlugin) registerWeb() {
http.HandleFunc("/vars/api", p.handleVarsAPI) r := chi.NewRouter()
http.HandleFunc("/vars", p.handleVars) r.HandleFunc("/api", p.handleVarsAPI)
p.bot.RegisterWeb("/vars", "Variables") r.HandleFunc("/", p.handleVars)
http.HandleFunc("/apppass/verify", p.handleAppPassCheck) p.bot.RegisterWeb(r, "/vars", "Variables")
http.HandleFunc("/apppass/api", p.handleAppPassAPI) r = chi.NewRouter()
http.HandleFunc("/apppass", p.handleAppPass) r.HandleFunc("/verify", p.handleAppPassCheck)
p.bot.RegisterWeb("/apppass", "App Pass") r.HandleFunc("/api", p.handleAppPassAPI)
r.HandleFunc("/", p.handleAppPass)
p.bot.RegisterWeb(r, "/apppass", "App Pass")
} }
func (p *AdminPlugin) handleAppPass(w http.ResponseWriter, r *http.Request) { func (p *AdminPlugin) handleAppPass(w http.ResponseWriter, r *http.Request) {

View File

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/go-chi/chi/v5"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -32,9 +33,10 @@ func New(b bot.Bot) *CliPlugin {
} }
func (p *CliPlugin) registerWeb() { func (p *CliPlugin) registerWeb() {
http.HandleFunc("/cli/api", p.handleWebAPI) r := chi.NewRouter()
http.HandleFunc("/cli", p.handleWeb) r.HandleFunc("/api", p.handleWebAPI)
p.bot.RegisterWeb("/cli", "CLI") r.HandleFunc("/", p.handleWeb)
p.bot.RegisterWeb(r, "/cli", "CLI")
} }
func (p *CliPlugin) handleWebAPI(w http.ResponseWriter, r *http.Request) { func (p *CliPlugin) handleWebAPI(w http.ResponseWriter, r *http.Request) {

128
plugins/counter/api.go Normal file
View File

@ -0,0 +1,128 @@
package counter
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog/log"
"github.com/velour/catbase/bot"
"github.com/velour/catbase/bot/msg"
"github.com/velour/catbase/bot/user"
)
func (p *CounterPlugin) registerWeb() {
r := chi.NewRouter()
r.HandleFunc("/api/users/{user}/items/{item}/increment", p.mkIncrementAPI(1))
r.HandleFunc("/api/users/{user}/items/{item}/decrement", p.mkIncrementAPI(-1))
r.HandleFunc("/api", p.handleCounterAPI)
r.HandleFunc("/", p.handleCounter)
p.b.RegisterWeb(r, "/counter", "Counter")
}
func (p *CounterPlugin) mkIncrementAPI(delta int) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
vars := map[string]string{}
userName := vars["user"]
itemName := vars["item"]
item, err := GetUserItem(p.db, userName, "", itemName)
if err != nil {
return
}
u := user.New(userName)
req := &bot.Request{
Conn: p.b.DefaultConnector(),
Kind: bot.Message,
Msg: msg.Message{
User: &u,
ChannelName: "#API",
Body: fmt.Sprintf("%s += %d", itemName, delta),
Time: time.Now(),
},
Values: nil,
Args: nil,
}
item.UpdateDelta(req, delta)
msg := fmt.Sprintf("%s changed their %s counter by %d via the amazing %s API",
userName, itemName, delta, p.cfg.Get("nick", "catbase"))
for _, ch := range p.cfg.GetArray("channels", []string{}) {
p.b.Send(p.b.DefaultConnector(), bot.Message, ch, msg)
}
j, _ := json.Marshal(struct{ Status bool }{true})
fmt.Fprint(w, string(j))
}
}
func (p *CounterPlugin) handleCounter(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, html)
}
func (p *CounterPlugin) handleCounterAPI(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
info := struct {
User string
Thing string
Action string
Password string
}{}
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&info)
if err != nil {
w.WriteHeader(500)
fmt.Fprint(w, err)
return
}
log.Debug().
Interface("postbody", info).
Msg("Got a POST")
if p.b.CheckPassword("", info.Password) {
w.WriteHeader(http.StatusForbidden)
j, _ := json.Marshal(struct{ Err string }{Err: "Invalid Password"})
w.Write(j)
return
}
nick, id := p.resolveUser(bot.Request{Conn: p.b.DefaultConnector()}, info.User)
item, err := GetUserItem(p.db, nick, id, info.Thing)
if err != nil {
log.Error().
Err(err).
Str("subject", info.User).
Str("itemName", info.Thing).
Msg("error finding item")
w.WriteHeader(404)
fmt.Fprint(w, err)
return
}
if info.Action == "++" {
item.UpdateDelta(nil, 1)
} else if info.Action == "--" {
item.UpdateDelta(nil, -1)
} else {
w.WriteHeader(400)
fmt.Fprint(w, "Invalid increment")
return
}
}
all, err := GetAllItems(p.db)
if err != nil {
w.WriteHeader(500)
fmt.Fprint(w, err)
return
}
data, err := json.Marshal(all)
if err != nil {
w.WriteHeader(500)
fmt.Fprint(w, err)
return
}
fmt.Fprint(w, string(data))
}
type Update struct {
Who string
What string
Amount int
}

View File

@ -2,15 +2,14 @@ package counter
import ( import (
"database/sql" "database/sql"
"encoding/json"
"fmt" "fmt"
"math/rand" "math/rand"
"net/http"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/velour/catbase/config"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@ -21,8 +20,9 @@ import (
// This is a counter plugin to count arbitrary things. // This is a counter plugin to count arbitrary things.
type CounterPlugin struct { type CounterPlugin struct {
Bot bot.Bot b bot.Bot
DB *sqlx.DB db *sqlx.DB
cfg *config.Config
} }
type Item struct { type Item struct {
@ -50,7 +50,7 @@ func GetAllItems(db *sqlx.DB) ([]Item, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Don't forget to embed the DB into all of that shiz // Don't forget to embed the db into all of that shiz
for i := range items { for i := range items {
items[i].DB = db items[i].DB = db
} }
@ -69,7 +69,7 @@ func GetItems(db *sqlx.DB, nick, id string) ([]Item, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Don't forget to embed the DB into all of that shiz // Don't forget to embed the db into all of that shiz
for i := range items { for i := range items {
items[i].DB = db items[i].DB = db
} }
@ -254,7 +254,7 @@ func (i *Item) Delete() error {
} }
func (p *CounterPlugin) migrate(r bot.Request) bool { func (p *CounterPlugin) migrate(r bot.Request) bool {
db := p.DB db := p.db
nicks := []string{} nicks := []string{}
err := db.Select(&nicks, `select distinct nick from counter where userid is null`) err := db.Select(&nicks, `select distinct nick from counter where userid is null`)
@ -321,8 +321,8 @@ func New(b bot.Bot) *CounterPlugin {
} }
cp := &CounterPlugin{ cp := &CounterPlugin{
Bot: b, b: b,
DB: b.DB(), db: b.DB(),
} }
b.RegisterRegex(cp, bot.Startup, regexp.MustCompile(`.*`), cp.migrate) b.RegisterRegex(cp, bot.Startup, regexp.MustCompile(`.*`), cp.migrate)
@ -367,15 +367,15 @@ func (p *CounterPlugin) mkAliasCmd(r bot.Request) bool {
what := r.Values["what"] what := r.Values["what"]
to := r.Values["to"] to := r.Values["to"]
if what == "" || to == "" { if what == "" || to == "" {
p.Bot.Send(r.Conn, bot.Message, fmt.Sprintf("You must provide all fields for an alias: %s", mkAliasRegex)) p.b.Send(r.Conn, bot.Message, fmt.Sprintf("You must provide all fields for an alias: %s", mkAliasRegex))
return true return true
} }
if _, err := MkAlias(p.DB, what, to); err != nil { if _, err := MkAlias(p.db, what, to); err != nil {
log.Error().Err(err).Msg("Could not mkalias") log.Error().Err(err).Msg("Could not mkalias")
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, "We're gonna need too much DB space to make an alias for your mom.") p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "We're gonna need too much db space to make an alias for your mom.")
return true return true
} }
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Created alias %s -> %s", p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Created alias %s -> %s",
what, to)) what, to))
return true return true
} }
@ -383,15 +383,15 @@ func (p *CounterPlugin) mkAliasCmd(r bot.Request) bool {
func (p *CounterPlugin) rmAliasCmd(r bot.Request) bool { func (p *CounterPlugin) rmAliasCmd(r bot.Request) bool {
what := r.Values["what"] what := r.Values["what"]
if what == "" { if what == "" {
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, "You must specify an alias to remove.") p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "You must specify an alias to remove.")
return true return true
} }
if err := RmAlias(p.DB, what); err != nil { if err := RmAlias(p.db, what); err != nil {
log.Error().Err(err).Msg("could not RmAlias") log.Error().Err(err).Msg("could not RmAlias")
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, "`sudo rm your mom` => Nope, she's staying with me.") p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "`sudo rm your mom` => Nope, she's staying with me.")
return true return true
} }
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, "`sudo rm your mom`") p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "`sudo rm your mom`")
return true return true
} }
@ -401,10 +401,10 @@ func (p *CounterPlugin) leaderboardCmd(r bot.Request) bool {
what := r.Values["what"] what := r.Values["what"]
if what == "" { if what == "" {
cmd = func() ([]Item, error) { return LeaderAll(p.DB) } cmd = func() ([]Item, error) { return LeaderAll(p.db) }
} else { } else {
itNameTxt = fmt.Sprintf(" for %s", what) itNameTxt = fmt.Sprintf(" for %s", what)
cmd = func() ([]Item, error) { return Leader(p.DB, what) } cmd = func() ([]Item, error) { return Leader(p.db, what) }
} }
its, err := cmd() its, err := cmd()
@ -412,7 +412,7 @@ func (p *CounterPlugin) leaderboardCmd(r bot.Request) bool {
log.Error().Err(err).Msg("Error with leaderboard") log.Error().Err(err).Msg("Error with leaderboard")
return false return false
} else if len(its) == 0 { } else if len(its) == 0 {
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, "There are not enough entries for a leaderboard.") p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "There are not enough entries for a leaderboard.")
return true return true
} }
@ -424,7 +424,7 @@ func (p *CounterPlugin) leaderboardCmd(r bot.Request) bool {
it.Item, it.Item,
) )
} }
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, out) p.b.Send(r.Conn, bot.Message, r.Msg.Channel, out)
return true return true
} }
@ -432,20 +432,20 @@ func (p *CounterPlugin) resetCmd(r bot.Request) bool {
nick, id := p.resolveUser(r, "") nick, id := p.resolveUser(r, "")
channel := r.Msg.Channel channel := r.Msg.Channel
items, err := GetItems(p.DB, nick, id) items, err := GetItems(p.db, nick, id)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("nick", nick). Str("nick", nick).
Msg("Error getting items to reset") Msg("Error getting items to reset")
p.Bot.Send(r.Conn, bot.Message, channel, "Something is technically wrong with your counters.") p.b.Send(r.Conn, bot.Message, channel, "Something is technically wrong with your counters.")
return true return true
} }
log.Debug().Msgf("Items: %+v", items) log.Debug().Msgf("Items: %+v", items)
for _, item := range items { for _, item := range items {
item.Delete() item.Delete()
} }
p.Bot.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s, you are as new, my son.", nick)) p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s, you are as new, my son.", nick))
return true return true
} }
@ -466,14 +466,14 @@ func (p *CounterPlugin) inspectCmd(r bot.Request) bool {
Str("id", id). Str("id", id).
Msg("Getting counter") Msg("Getting counter")
// pull all of the items associated with "subject" // pull all of the items associated with "subject"
items, err := GetItems(p.DB, nick, id) items, err := GetItems(p.db, nick, id)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("nick", nick). Str("nick", nick).
Str("id", id). Str("id", id).
Msg("Error retrieving items") Msg("Error retrieving items")
p.Bot.Send(c, bot.Message, channel, "Something went wrong finding that counter;") p.b.Send(c, bot.Message, channel, "Something went wrong finding that counter;")
return true return true
} }
@ -493,11 +493,11 @@ func (p *CounterPlugin) inspectCmd(r bot.Request) bool {
resp += "." resp += "."
if count == 0 { if count == 0 {
p.Bot.Send(c, bot.Message, channel, fmt.Sprintf("%s has no counters.", nick)) p.b.Send(c, bot.Message, channel, fmt.Sprintf("%s has no counters.", nick))
return true return true
} }
p.Bot.Send(c, bot.Message, channel, resp) p.b.Send(c, bot.Message, channel, resp)
return true return true
} }
@ -507,7 +507,7 @@ func (p *CounterPlugin) clearCmd(r bot.Request) bool {
channel := r.Msg.Channel channel := r.Msg.Channel
c := r.Conn c := r.Conn
it, err := GetUserItem(p.DB, nick, id, itemName) it, err := GetUserItem(p.db, nick, id, itemName)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -515,7 +515,7 @@ func (p *CounterPlugin) clearCmd(r bot.Request) bool {
Str("id", id). Str("id", id).
Str("itemName", itemName). Str("itemName", itemName).
Msg("Error getting item to remove") Msg("Error getting item to remove")
p.Bot.Send(c, bot.Message, channel, "Something went wrong removing that counter;") p.b.Send(c, bot.Message, channel, "Something went wrong removing that counter;")
return true return true
} }
err = it.Delete() err = it.Delete()
@ -526,11 +526,11 @@ func (p *CounterPlugin) clearCmd(r bot.Request) bool {
Str("id", id). Str("id", id).
Str("itemName", itemName). Str("itemName", itemName).
Msg("Error removing item") Msg("Error removing item")
p.Bot.Send(c, bot.Message, channel, "Something went wrong removing that counter;") p.b.Send(c, bot.Message, channel, "Something went wrong removing that counter;")
return true return true
} }
p.Bot.Send(c, bot.Action, channel, fmt.Sprintf("chops a few %s out of his brain", p.b.Send(c, bot.Action, channel, fmt.Sprintf("chops a few %s out of his brain",
itemName)) itemName))
return true return true
} }
@ -545,10 +545,10 @@ func (p *CounterPlugin) countCmd(r bot.Request) bool {
} }
var item Item var item Item
item, err := GetUserItem(p.DB, nick, id, itemName) item, err := GetUserItem(p.db, nick, id, itemName)
switch { switch {
case err == sql.ErrNoRows: case err == sql.ErrNoRows:
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("I don't think %s has any %s.", p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("I don't think %s has any %s.",
nick, itemName)) nick, itemName))
return true return true
case err != nil: case err != nil:
@ -561,7 +561,7 @@ func (p *CounterPlugin) countCmd(r bot.Request) bool {
return true return true
} }
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("%s has %d %s.", nick, item.Count, p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("%s has %d %s.", nick, item.Count,
itemName)) itemName))
return true return true
@ -575,7 +575,7 @@ func (p *CounterPlugin) incrementCmd(r bot.Request) bool {
itemName := r.Values["thing"] itemName := r.Values["thing"]
channel := r.Msg.Channel channel := r.Msg.Channel
// ++ those fuckers // ++ those fuckers
item, err := GetUserItem(p.DB, nick, id, itemName) item, err := GetUserItem(p.db, nick, id, itemName)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -588,7 +588,7 @@ func (p *CounterPlugin) incrementCmd(r bot.Request) bool {
} }
log.Debug().Msgf("About to update item: %#v", item) log.Debug().Msgf("About to update item: %#v", item)
item.UpdateDelta(&r, 1) item.UpdateDelta(&r, 1)
p.Bot.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick, p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick,
item.Count, item.Item)) item.Count, item.Item))
return true return true
} }
@ -601,7 +601,7 @@ func (p *CounterPlugin) decrementCmd(r bot.Request) bool {
itemName := r.Values["thing"] itemName := r.Values["thing"]
channel := r.Msg.Channel channel := r.Msg.Channel
// -- those fuckers // -- those fuckers
item, err := GetUserItem(p.DB, nick, id, itemName) item, err := GetUserItem(p.db, nick, id, itemName)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -613,7 +613,7 @@ func (p *CounterPlugin) decrementCmd(r bot.Request) bool {
return false return false
} }
item.UpdateDelta(&r, -1) item.UpdateDelta(&r, -1)
p.Bot.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick, p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick,
item.Count, item.Item)) item.Count, item.Item))
return true return true
} }
@ -623,7 +623,7 @@ func (p *CounterPlugin) addToCmd(r bot.Request) bool {
itemName := r.Values["thing"] itemName := r.Values["thing"]
channel := r.Msg.Channel channel := r.Msg.Channel
// += those fuckers // += those fuckers
item, err := GetUserItem(p.DB, nick, id, itemName) item, err := GetUserItem(p.db, nick, id, itemName)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -637,7 +637,7 @@ func (p *CounterPlugin) addToCmd(r bot.Request) bool {
n, _ := strconv.Atoi(r.Values["amount"]) n, _ := strconv.Atoi(r.Values["amount"])
log.Debug().Msgf("About to update item by %d: %#v", n, item) log.Debug().Msgf("About to update item by %d: %#v", n, item)
item.UpdateDelta(&r, n) item.UpdateDelta(&r, n)
p.Bot.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick, p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick,
item.Count, item.Item)) item.Count, item.Item))
return true return true
} }
@ -647,7 +647,7 @@ func (p *CounterPlugin) removeFromCmd(r bot.Request) bool {
itemName := r.Values["thing"] itemName := r.Values["thing"]
channel := r.Msg.Channel channel := r.Msg.Channel
// -= those fuckers // -= those fuckers
item, err := GetUserItem(p.DB, nick, id, itemName) item, err := GetUserItem(p.db, nick, id, itemName)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@ -661,19 +661,19 @@ func (p *CounterPlugin) removeFromCmd(r bot.Request) bool {
n, _ := strconv.Atoi(r.Values["amount"]) n, _ := strconv.Atoi(r.Values["amount"])
log.Debug().Msgf("About to update item by -%d: %#v", n, item) log.Debug().Msgf("About to update item by -%d: %#v", n, item)
item.UpdateDelta(&r, -n) item.UpdateDelta(&r, -n)
p.Bot.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick, p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick,
item.Count, item.Item)) item.Count, item.Item))
return true return true
} }
// Help responds to help requests. Every plugin must implement a help function. // Help responds to help requests. Every plugin must implement a help function.
func (p *CounterPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool { func (p *CounterPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool {
p.Bot.Send(c, bot.Message, message.Channel, "You can set counters incrementally by using "+ p.b.Send(c, bot.Message, message.Channel, "You can set counters incrementally by using "+
"`<noun>++` and `<noun>--`. You can see all of your counters using "+ "`<noun>++` and `<noun>--`. You can see all of your counters using "+
"`inspect`, erase them with `clear`, and view single counters with "+ "`inspect`, erase them with `clear`, and view single counters with "+
"`count`.") "`count`.")
p.Bot.Send(c, bot.Message, message.Channel, "You can create aliases with `!mkalias <alias> <original>`") p.b.Send(c, bot.Message, message.Channel, "You can create aliases with `!mkalias <alias> <original>`")
p.Bot.Send(c, bot.Message, message.Channel, "You can remove aliases with `!rmalias <alias>`") p.b.Send(c, bot.Message, message.Channel, "You can remove aliases with `!rmalias <alias>`")
return true return true
} }
@ -688,7 +688,7 @@ func (p *CounterPlugin) teaMatchCmd(r bot.Request) bool {
itemName := strings.ToLower(submatches[1]) itemName := strings.ToLower(submatches[1])
// We will specifically allow :tea: to keep compatability // We will specifically allow :tea: to keep compatability
item, err := GetUserItem(p.DB, nick, id, itemName) item, err := GetUserItem(p.db, nick, id, itemName)
if err != nil || (item.Count == 0 && item.Item != ":tea:") { if err != nil || (item.Count == 0 && item.Item != ":tea:") {
log.Error(). log.Error().
Err(err). Err(err).
@ -703,7 +703,7 @@ func (p *CounterPlugin) teaMatchCmd(r bot.Request) bool {
delta = -1 delta = -1
} }
item.UpdateDelta(&r, delta) item.UpdateDelta(&r, delta)
p.Bot.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s... %s has %d %s", p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s... %s has %d %s",
strings.Join(everyDayImShuffling([]string{"bleep", "bloop", "blop"}), "-"), nick, item.Count, itemName)) strings.Join(everyDayImShuffling([]string{"bleep", "bloop", "blop"}), "-"), nick, item.Count, itemName))
return true return true
} }
@ -717,84 +717,6 @@ func everyDayImShuffling(vals []string) []string {
return ret return ret
} }
func (p *CounterPlugin) registerWeb() {
http.HandleFunc("/counter/api", p.handleCounterAPI)
http.HandleFunc("/counter", p.handleCounter)
p.Bot.RegisterWeb("/counter", "Counter")
}
func (p *CounterPlugin) handleCounter(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, html)
}
func (p *CounterPlugin) handleCounterAPI(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
info := struct {
User string
Thing string
Action string
Password string
}{}
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&info)
if err != nil {
w.WriteHeader(500)
fmt.Fprint(w, err)
return
}
log.Debug().
Interface("postbody", info).
Msg("Got a POST")
if p.Bot.CheckPassword("", info.Password) {
w.WriteHeader(http.StatusForbidden)
j, _ := json.Marshal(struct{ Err string }{Err: "Invalid Password"})
w.Write(j)
return
}
nick, id := p.resolveUser(bot.Request{Conn: p.Bot.DefaultConnector()}, info.User)
item, err := GetUserItem(p.DB, nick, id, info.Thing)
if err != nil {
log.Error().
Err(err).
Str("subject", info.User).
Str("itemName", info.Thing).
Msg("error finding item")
w.WriteHeader(404)
fmt.Fprint(w, err)
return
}
if info.Action == "++" {
item.UpdateDelta(nil, 1)
} else if info.Action == "--" {
item.UpdateDelta(nil, -1)
} else {
w.WriteHeader(400)
fmt.Fprint(w, "Invalid increment")
return
}
}
all, err := GetAllItems(p.DB)
if err != nil {
w.WriteHeader(500)
fmt.Fprint(w, err)
return
}
data, err := json.Marshal(all)
if err != nil {
w.WriteHeader(500)
fmt.Fprint(w, err)
return
}
fmt.Fprint(w, string(data))
}
type Update struct {
Who string
What string
Amount int
}
type updateFunc func(bot.Request, Update) type updateFunc func(bot.Request, Update)
var updateFuncs = []updateFunc{} var updateFuncs = []updateFunc{}

View File

@ -14,6 +14,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
@ -43,7 +44,7 @@ type alias struct {
} }
func (a *alias) resolve(db *sqlx.DB) (*Factoid, error) { func (a *alias) resolve(db *sqlx.DB) (*Factoid, error) {
// perform DB query to fill the To field // perform db query to fill the To field
q := `select fact, next from factoid_alias where fact=?` q := `select fact, next from factoid_alias where fact=?`
var next alias var next alias
err := db.Get(&next, q, a.Next) err := db.Get(&next, q, a.Next)
@ -801,10 +802,11 @@ func (p *FactoidPlugin) factTimer(c bot.Connector, channel string) {
// Register any web URLs desired // Register any web URLs desired
func (p *FactoidPlugin) registerWeb() { func (p *FactoidPlugin) registerWeb() {
http.HandleFunc("/factoid/api", p.serveAPI) r := chi.NewRouter()
http.HandleFunc("/factoid/req", p.serveQuery) r.HandleFunc("/api", p.serveAPI)
http.HandleFunc("/factoid", p.serveQuery) r.HandleFunc("/req", p.serveQuery)
p.Bot.RegisterWeb("/factoid", "Factoid") r.HandleFunc("/", p.serveQuery)
p.Bot.RegisterWeb(r, "/factoid", "Factoid")
} }
func linkify(text string) template.HTML { func linkify(text string) template.HTML {

View File

@ -7,19 +7,21 @@ import (
"net/url" "net/url"
"sort" "sort"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/velour/catbase/bot" "github.com/velour/catbase/bot"
) )
func (p *MemePlugin) registerWeb(c bot.Connector) { func (p *MemePlugin) registerWeb(c bot.Connector) {
http.HandleFunc("/slash/meme", p.slashMeme(c)) r := chi.NewRouter()
http.HandleFunc("/meme/img/", p.img) r.HandleFunc("/slash", p.slashMeme(c))
http.HandleFunc("/meme/all", p.all) r.HandleFunc("/img", p.img)
http.HandleFunc("/meme/add", p.addMeme) r.HandleFunc("/all", p.all)
http.HandleFunc("/meme/rm", p.rmMeme) r.HandleFunc("/add", p.addMeme)
http.HandleFunc("/meme", p.webRoot) r.HandleFunc("/rm", p.rmMeme)
p.bot.RegisterWeb("/meme", "Memes") r.HandleFunc("/", p.webRoot)
p.bot.RegisterWeb(r, "/meme", "Memes")
} }
type webResp struct { type webResp struct {

View File

@ -2,6 +2,6 @@ package webshit
//func TestWebshit_Check(t *testing.T) { //func TestWebshit_Check(t *testing.T) {
// mb := bot.NewMockBot() // mb := bot.NewMockBot()
// ws := New(mb.DB()) // ws := New(mb.db())
// ws.checkBids() // ws.checkBids()
//} //}

View File

@ -34,7 +34,7 @@ func makeMessage(nick, payload string, r *regexp.Regexp) bot.Request {
func makePlugin(t *testing.T) (*RememberPlugin, *fact.FactoidPlugin, *bot.MockBot) { func makePlugin(t *testing.T) (*RememberPlugin, *fact.FactoidPlugin, *bot.MockBot) {
mb := bot.NewMockBot() mb := bot.NewMockBot()
f := fact.New(mb) // for DB table f := fact.New(mb) // for db table
p := New(mb) p := New(mb)
assert.NotNil(t, p) assert.NotNil(t, p)
return p, f, mb return p, f, mb

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/go-chi/chi/v5"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/velour/catbase/bot" "github.com/velour/catbase/bot"
@ -28,16 +29,17 @@ func New(b bot.Bot) *SecretsPlugin {
} }
func (p *SecretsPlugin) registerWeb() { func (p *SecretsPlugin) registerWeb() {
http.HandleFunc("/secrets/add", p.handleRegister) r := chi.NewRouter()
http.HandleFunc("/secrets/remove", p.handleRemove) r.HandleFunc("/add", p.handleRegister)
http.HandleFunc("/secrets/all", p.handleAll) r.HandleFunc("/remove", p.handleRemove)
http.HandleFunc("/secrets/test", func(w http.ResponseWriter, r *http.Request) { r.HandleFunc("/all", p.handleAll)
r.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
value := r.URL.Query().Get("test") value := r.URL.Query().Get("test")
j, _ := json.Marshal(map[string]string{"value": value}) j, _ := json.Marshal(map[string]string{"value": value})
w.Write(j) w.Write(j)
}) })
http.HandleFunc("/secrets", p.handleIndex) r.HandleFunc("/", p.handleIndex)
p.b.RegisterWeb("/secrets", "Secrets") p.b.RegisterWeb(r, "/secrets", "Secrets")
} }
func (p *SecretsPlugin) registerSecret(key, value string) error { func (p *SecretsPlugin) registerSecret(key, value string) error {