Compare commits

..

No commits in common. "2e0be87dd97666aa0ee15ee7548bdd22d75522a0" and "3f7598e92cfb0c37493f2e6e5f9aaa70bd2d0872" have entirely different histories.

4 changed files with 18 additions and 44 deletions

View File

@ -318,9 +318,6 @@ func (b *bot) Register(p Plugin, kind Kind, cb Callback) {
// GetPassword returns a random password generated for the bot // GetPassword returns a random password generated for the bot
// Passwords expire in 24h and are used for the web interface // Passwords expire in 24h and are used for the web interface
func (b *bot) GetPassword() string { func (b *bot) GetPassword() string {
if override := b.config.Get("bot.password", ""); override != "" {
return override
}
if b.passwordCreated.Before(time.Now().Add(-24 * time.Hour)) { if b.passwordCreated.Before(time.Now().Add(-24 * time.Hour)) {
adjs := b.config.GetArray("bot.passwordAdjectives", []string{"very"}) adjs := b.config.GetArray("bot.passwordAdjectives", []string{"very"})
nouns := b.config.GetArray("bot.passwordNouns", []string{"noun"}) nouns := b.config.GetArray("bot.passwordNouns", []string{"noun"})
@ -407,11 +404,10 @@ func PluginName(p Plugin) string {
} }
func (b *bot) CheckPassword(secret, password string) bool { func (b *bot) CheckPassword(secret, password string) bool {
log.Debug().Msgf("CheckPassword(%s, %s) => b.password=%s, b.GetPassword()=%s", secret, password, b.password, b.GetPassword())
if password == "" { if password == "" {
return false return false
} }
if b.GetPassword() == password { if b.password == password {
return true return true
} }
parts := strings.SplitN(password, ":", 2) parts := strings.SplitN(password, ":", 2)

View File

@ -109,7 +109,7 @@ func (mb *MockBot) UploadEmojy(emojy, path string) error { return nil
func (mb *MockBot) RegisterFilter(s string, f func(string) string) {} func (mb *MockBot) RegisterFilter(s string, f func(string) string) {}
func NewMockBot() *MockBot { func NewMockBot() *MockBot {
cfg := config.ReadConfig(":memory:") cfg := config.ReadConfig("file::memory:?mode=memory&cache=shared")
b := MockBot{ b := MockBot{
Cfg: cfg, Cfg: cfg,
Messages: make([]string, 0), Messages: make([]string, 0),

View File

@ -4,7 +4,6 @@ import (
"embed" "embed"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/velour/catbase/bot/user"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
@ -24,14 +23,11 @@ func (p *CounterPlugin) registerWeb() {
r := chi.NewRouter() r := chi.NewRouter()
requests := p.cfg.GetInt("counter.requestsPer", 1) requests := p.cfg.GetInt("counter.requestsPer", 1)
seconds := p.cfg.GetInt("counter.seconds", 1) seconds := p.cfg.GetInt("counter.seconds", 1)
dur := time.Duration(seconds) * time.Second r.Use(httprate.LimitByIP(requests, time.Duration(seconds)*time.Second))
subrouter := chi.NewRouter() r.HandleFunc("/api/users/{user}/items/{item}/increment/{delta}", p.mkIncrementByNAPI(1))
subrouter.Use(httprate.LimitByIP(requests, dur)) r.HandleFunc("/api/users/{user}/items/{item}/decrement/{delta}", p.mkIncrementByNAPI(-1))
subrouter.HandleFunc("/api/users/{user}/items/{item}/increment/{delta}", p.mkIncrementByNAPI(1)) r.HandleFunc("/api/users/{user}/items/{item}/increment", p.mkIncrementAPI(1))
subrouter.HandleFunc("/api/users/{user}/items/{item}/decrement/{delta}", p.mkIncrementByNAPI(-1)) r.HandleFunc("/api/users/{user}/items/{item}/decrement", p.mkIncrementAPI(-1))
subrouter.HandleFunc("/api/users/{user}/items/{item}/increment", p.mkIncrementAPI(1))
subrouter.HandleFunc("/api/users/{user}/items/{item}/decrement", p.mkIncrementAPI(-1))
r.Mount("/", subrouter)
r.HandleFunc("/api", p.handleCounterAPI) r.HandleFunc("/api", p.handleCounterAPI)
r.HandleFunc("/", p.handleCounter) r.HandleFunc("/", p.handleCounter)
p.b.RegisterWebName(r, "/counter", "Counter") p.b.RegisterWebName(r, "/counter", "Counter")
@ -209,25 +205,13 @@ func (p *CounterPlugin) handleCounterAPI(w http.ResponseWriter, r *http.Request)
log.Debug(). log.Debug().
Interface("postbody", info). Interface("postbody", info).
Msg("Got a POST") Msg("Got a POST")
if !p.b.CheckPassword("", info.Password) { if p.b.CheckPassword("", info.Password) {
w.WriteHeader(http.StatusForbidden) w.WriteHeader(http.StatusForbidden)
j, _ := json.Marshal(struct{ Err string }{Err: "Invalid Password"}) j, _ := json.Marshal(struct{ Err string }{Err: "Invalid Password"})
w.Write(j) w.Write(j)
return return
} }
req := bot.Request{ nick, id := p.resolveUser(bot.Request{Conn: p.b.DefaultConnector()}, info.User)
Conn: p.b.DefaultConnector(),
Kind: bot.Message,
Msg: msg.Message{
User: &user.User{
ID: "",
Name: info.User,
Admin: false,
},
},
}
// resolveUser requires a "full" request object so we are faking it
nick, id := p.resolveUser(req, info.User)
item, err := GetUserItem(p.db, nick, id, info.Thing) item, err := GetUserItem(p.db, nick, id, info.Thing)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -254,14 +238,12 @@ func (p *CounterPlugin) handleCounterAPI(w http.ResponseWriter, r *http.Request)
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(500)
fmt.Fprint(w, err) fmt.Fprint(w, err)
log.Error().Err(err).Msg("Error getting items")
return return
} }
data, err := json.Marshal(all) data, err := json.Marshal(all)
if err != nil { if err != nil {
w.WriteHeader(500) w.WriteHeader(500)
fmt.Fprint(w, err) fmt.Fprint(w, err)
log.Error().Err(err).Msg("Error marshaling items")
return return
} }
fmt.Fprint(w, string(data)) fmt.Fprint(w, string(data))

View File

@ -32,7 +32,7 @@ type Item struct {
Nick string Nick string
Item string Item string
Count int Count int
UserID sql.NullString UserID string
} }
type alias struct { type alias struct {
@ -190,7 +190,7 @@ func GetUserItem(db *sqlx.DB, nick, id, itemName string) (Item, error) {
item.ID = -1 item.ID = -1
item.Nick = nick item.Nick = nick
item.Item = itemName item.Item = itemName
item.UserID.Scan(id) item.UserID = id
case nil: case nil:
default: default:
return Item{}, err return Item{}, err
@ -208,7 +208,7 @@ func GetUserItem(db *sqlx.DB, nick, id, itemName string) (Item, error) {
// Create saves a counter // Create saves a counter
func (i *Item) Create() error { func (i *Item) Create() error {
res, err := i.Exec(`insert into counter (nick, item, count, userid) values (?, ?, ?, ?);`, res, err := i.Exec(`insert into counter (nick, item, count, userid) values (?, ?, ?, ?);`,
i.Nick, i.Item, i.Count, i.UserID.String) i.Nick, i.Item, i.Count, i.UserID)
if err != nil { if err != nil {
return err return err
} }
@ -226,9 +226,7 @@ func (i *Item) Update(r *bot.Request, value int) error {
return i.Delete() return i.Delete()
} }
if i.ID == -1 { if i.ID == -1 {
if err := i.Create(); err != nil { i.Create()
return err
}
} }
log.Debug(). log.Debug().
Interface("i", i). Interface("i", i).
@ -288,20 +286,21 @@ func (p *CounterPlugin) migrate(r bot.Request) bool {
func setupDB(b bot.Bot) error { func setupDB(b bot.Bot) error {
db := b.DB() db := b.DB()
tx := db.MustBegin()
db.MustExec(`create table if not exists counter ( db.MustExec(`create table if not exists counter (
id integer primary key, id integer primary key,
nick string, nick string,
item string, item string,
count integer, count integer
userid string
);`) );`)
db.MustExec(`create table if not exists counter_alias ( db.MustExec(`create table if not exists counter_alias (
id integer PRIMARY KEY AUTOINCREMENT, id integer PRIMARY KEY AUTOINCREMENT,
item string NOT NULL UNIQUE, item string NOT NULL UNIQUE,
points_to string NOT NULL points_to string NOT NULL
);`) );`)
tx.Commit()
tx := db.MustBegin() tx = db.MustBegin()
count := 0 count := 0
err := tx.Get(&count, `SELECT count(*) FROM pragma_table_info('counter') where name='userid'`) err := tx.Get(&count, `SELECT count(*) FROM pragma_table_info('counter') where name='userid'`)
if err != nil { if err != nil {
@ -595,10 +594,7 @@ func (p *CounterPlugin) incrementCmd(r bot.Request) bool {
} }
log.Debug().Msgf("About to update item: %#v", item) log.Debug().Msgf("About to update item: %#v", item)
p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick, item.Count+1, item.Item)) p.b.Send(r.Conn, bot.Message, channel, fmt.Sprintf("%s has %d %s.", nick, item.Count+1, item.Item))
err = item.UpdateDelta(&r, 1) item.UpdateDelta(&r, 1)
if err != nil {
log.Error().Err(err).Msgf("Could not UpdateDelta")
}
return true return true
} }