Compare commits

..

6 Commits

Author SHA1 Message Date
Chris Sexton 2e0be87dd9 counter: fix ID null string regression
Also remove shared DB cache as it seems to not be necessary.
2022-07-17 14:46:09 -04:00
Chris Sexton 55ba4996eb bot: add override password for staging/testing modes 2022-07-17 14:46:09 -04:00
Chris Sexton 96229c9e11 counter: change http rate limiter criteria 2022-07-17 14:46:09 -04:00
Chris Sexton 72c32e1c16 counter: fix bug where anything but the right password worked 2022-07-17 14:46:09 -04:00
Chris Sexton df0b5d709b counter: fix bug where user ID was not passed to resolveUser 2022-07-17 14:46:09 -04:00
Chris Sexton aa7d377c08 counter: update model to reflect that the ID may be null 2022-07-17 14:46:09 -04:00
4 changed files with 44 additions and 18 deletions

View File

@ -318,6 +318,9 @@ func (b *bot) Register(p Plugin, kind Kind, cb Callback) {
// GetPassword returns a random password generated for the bot // GetPassword returns a random password generated for the bot
// Passwords expire in 24h and are used for the web interface // Passwords expire in 24h and are used for the web interface
func (b *bot) GetPassword() string { func (b *bot) GetPassword() string {
if override := b.config.Get("bot.password", ""); override != "" {
return override
}
if b.passwordCreated.Before(time.Now().Add(-24 * time.Hour)) { if b.passwordCreated.Before(time.Now().Add(-24 * time.Hour)) {
adjs := b.config.GetArray("bot.passwordAdjectives", []string{"very"}) adjs := b.config.GetArray("bot.passwordAdjectives", []string{"very"})
nouns := b.config.GetArray("bot.passwordNouns", []string{"noun"}) nouns := b.config.GetArray("bot.passwordNouns", []string{"noun"})
@ -404,10 +407,11 @@ func PluginName(p Plugin) string {
} }
func (b *bot) CheckPassword(secret, password string) bool { func (b *bot) CheckPassword(secret, password string) bool {
log.Debug().Msgf("CheckPassword(%s, %s) => b.password=%s, b.GetPassword()=%s", secret, password, b.password, b.GetPassword())
if password == "" { if password == "" {
return false return false
} }
if b.password == password { if b.GetPassword() == password {
return true return true
} }
parts := strings.SplitN(password, ":", 2) parts := strings.SplitN(password, ":", 2)

View File

@ -109,7 +109,7 @@ func (mb *MockBot) UploadEmojy(emojy, path string) error { return nil
func (mb *MockBot) RegisterFilter(s string, f func(string) string) {} func (mb *MockBot) RegisterFilter(s string, f func(string) string) {}
func NewMockBot() *MockBot { func NewMockBot() *MockBot {
cfg := config.ReadConfig("file::memory:?mode=memory&cache=shared") cfg := config.ReadConfig(":memory:")
b := MockBot{ b := MockBot{
Cfg: cfg, Cfg: cfg,
Messages: make([]string, 0), Messages: make([]string, 0),

View File

@ -4,6 +4,7 @@ import (
"embed" "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/velour/catbase/bot/user"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
@ -23,11 +24,14 @@ func (p *CounterPlugin) registerWeb() {
r := chi.NewRouter() r := chi.NewRouter()
requests := p.cfg.GetInt("counter.requestsPer", 1) requests := p.cfg.GetInt("counter.requestsPer", 1)
seconds := p.cfg.GetInt("counter.seconds", 1) seconds := p.cfg.GetInt("counter.seconds", 1)
r.Use(httprate.LimitByIP(requests, time.Duration(seconds)*time.Second)) dur := time.Duration(seconds) * time.Second
r.HandleFunc("/api/users/{user}/items/{item}/increment/{delta}", p.mkIncrementByNAPI(1)) subrouter := chi.NewRouter()
r.HandleFunc("/api/users/{user}/items/{item}/decrement/{delta}", p.mkIncrementByNAPI(-1)) subrouter.Use(httprate.LimitByIP(requests, dur))
r.HandleFunc("/api/users/{user}/items/{item}/increment", p.mkIncrementAPI(1)) subrouter.HandleFunc("/api/users/{user}/items/{item}/increment/{delta}", p.mkIncrementByNAPI(1))
r.HandleFunc("/api/users/{user}/items/{item}/decrement", p.mkIncrementAPI(-1)) subrouter.HandleFunc("/api/users/{user}/items/{item}/decrement/{delta}", p.mkIncrementByNAPI(-1))
subrouter.HandleFunc("/api/users/{user}/items/{item}/increment", p.mkIncrementAPI(1))
subrouter.HandleFunc("/api/users/{user}/items/{item}/decrement", p.mkIncrementAPI(-1))
r.Mount("/", subrouter)
r.HandleFunc("/api", p.handleCounterAPI) r.HandleFunc("/api", p.handleCounterAPI)
r.HandleFunc("/", p.handleCounter) r.HandleFunc("/", p.handleCounter)
p.b.RegisterWebName(r, "/counter", "Counter") p.b.RegisterWebName(r, "/counter", "Counter")
@ -205,13 +209,25 @@ func (p *CounterPlugin) handleCounterAPI(w http.ResponseWriter, r *http.Request)
log.Debug(). log.Debug().
Interface("postbody", info). Interface("postbody", info).
Msg("Got a POST") Msg("Got a POST")
if p.b.CheckPassword("", info.Password) { if !p.b.CheckPassword("", info.Password) {
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
j, _ := json.Marshal(struct{ Err string }{Err: "Invalid Password"}) j, _ := json.Marshal(struct{ Err string }{Err: "Invalid Password"})
w.Write(j) w.Write(j)
return return
} }
nick, id := p.resolveUser(bot.Request{Conn: p.b.DefaultConnector()}, info.User) req := bot.Request{
Conn: p.b.DefaultConnector(),
Kind: bot.Message,
Msg: msg.Message{
User: &user.User{
ID: "",
Name: info.User,
Admin: false,
},
},
}
// resolveUser requires a "full" request object so we are faking it
nick, id := p.resolveUser(req, info.User)
item, err := GetUserItem(p.db, nick, id, info.Thing) item, err := GetUserItem(p.db, nick, id, info.Thing)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -238,12 +254,14 @@ func (p *CounterPlugin) handleCounterAPI(w http.ResponseWriter, r *http.Request)
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(500)
fmt.Fprint(w, err) fmt.Fprint(w, err)
log.Error().Err(err).Msg("Error getting items")
return return
} }
data, err := json.Marshal(all) data, err := json.Marshal(all)
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(500)
fmt.Fprint(w, err) fmt.Fprint(w, err)
log.Error().Err(err).Msg("Error marshaling items")
return return
} }
fmt.Fprint(w, string(data)) fmt.Fprint(w, string(data))

View File

@ -32,7 +32,7 @@ type Item struct {
Nick string Nick string
Item string Item string
Count int Count int
UserID string UserID sql.NullString
} }
type alias struct { type alias struct {
@ -190,7 +190,7 @@ func GetUserItem(db *sqlx.DB, nick, id, itemName string) (Item, error) {
item.ID = -1 item.ID = -1
item.Nick = nick item.Nick = nick
item.Item = itemName item.Item = itemName
item.UserID = id item.UserID.Scan(id)
case nil: case nil:
default: default:
return Item{}, err return Item{}, err
@ -208,7 +208,7 @@ func GetUserItem(db *sqlx.DB, nick, id, itemName string) (Item, error) {
// Create saves a counter // Create saves a counter
func (i *Item) Create() error { func (i *Item) Create() error {
res, err := i.Exec(`insert into counter (nick, item, count, userid) values (?, ?, ?, ?);`, res, err := i.Exec(`insert into counter (nick, item, count, userid) values (?, ?, ?, ?);`,
i.Nick, i.Item, i.Count, i.UserID) i.Nick, i.Item, i.Count, i.UserID.String)
if err != nil { if err != nil {
return err return err
} }
@ -226,7 +226,9 @@ func (i *Item) Update(r *bot.Request, value int) error {
return i.Delete() return i.Delete()
} }
if i.ID == -1 { if i.ID == -1 {
i.Create() if err := i.Create(); err != nil {
return err
}
} }
log.Debug(). log.Debug().
Interface("i", i). Interface("i", i).
@ -286,21 +288,20 @@ func (p *CounterPlugin) migrate(r bot.Request) bool {
func setupDB(b bot.Bot) error { func setupDB(b bot.Bot) error {
db := b.DB() db := b.DB()
tx := db.MustBegin()
db.MustExec(`create table if not exists counter ( db.MustExec(`create table if not exists counter (
id integer primary key, id integer primary key,
nick string, nick string,
item string, item string,
count integer count integer,
userid string
);`) );`)
db.MustExec(`create table if not exists counter_alias ( db.MustExec(`create table if not exists counter_alias (
id integer PRIMARY KEY AUTOINCREMENT, id integer PRIMARY KEY AUTOINCREMENT,
item string NOT NULL UNIQUE, item string NOT NULL UNIQUE,
points_to string NOT NULL points_to string NOT NULL
);`) );`)
tx.Commit()
tx = db.MustBegin() tx := db.MustBegin()
count := 0 count := 0
err := tx.Get(&count, `SELECT count(*) FROM pragma_table_info('counter') where name='userid'`) err := tx.Get(&count, `SELECT count(*) FROM pragma_table_info('counter') where name='userid'`)
if err != nil { if err != nil {
@ -594,7 +595,10 @@ func (p *CounterPlugin) incrementCmd(r bot.Request) bool {
} }
log.Debug().Msgf("About to update item: %#v", item) log.Debug().Msgf("About to update item: %#v", item)
p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick, item.Count+1, item.Item)) p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick, item.Count+1, item.Item))
item.UpdateDelta(&r, 1) err = item.UpdateDelta(&r, 1)
if err != nil {
log.Error().Err(err).Msgf("Could not UpdateDelta")
}
return true return true
} }