mirror of https://github.com/velour/catbase.git
commit
2ed92927a2
42
bot/bot.go
42
bot/bot.go
|
@ -29,6 +29,9 @@ type bot struct {
|
||||||
// channel -> plugin
|
// channel -> plugin
|
||||||
pluginBlacklist map[string]bool
|
pluginBlacklist map[string]bool
|
||||||
|
|
||||||
|
// plugin, this is bot-wide
|
||||||
|
pluginWhitelist map[string]bool
|
||||||
|
|
||||||
// Users holds information about all of our friends
|
// Users holds information about all of our friends
|
||||||
users []user.User
|
users []user.User
|
||||||
// Represents the bot
|
// Represents the bot
|
||||||
|
@ -84,6 +87,7 @@ func New(config *config.Config, connector Connector) Bot {
|
||||||
plugins: make(map[string]Plugin),
|
plugins: make(map[string]Plugin),
|
||||||
pluginOrdering: make([]string, 0),
|
pluginOrdering: make([]string, 0),
|
||||||
pluginBlacklist: make(map[string]bool),
|
pluginBlacklist: make(map[string]bool),
|
||||||
|
pluginWhitelist: make(map[string]bool),
|
||||||
conn: connector,
|
conn: connector,
|
||||||
users: users,
|
users: users,
|
||||||
me: users[0],
|
me: users[0],
|
||||||
|
@ -97,6 +101,7 @@ func New(config *config.Config, connector Connector) Bot {
|
||||||
bot.migrateDB()
|
bot.migrateDB()
|
||||||
|
|
||||||
bot.RefreshPluginBlacklist()
|
bot.RefreshPluginBlacklist()
|
||||||
|
bot.RefreshPluginWhitelist()
|
||||||
|
|
||||||
http.HandleFunc("/", bot.serveRoot)
|
http.HandleFunc("/", bot.serveRoot)
|
||||||
|
|
||||||
|
@ -142,7 +147,12 @@ func (b *bot) migrateDB() {
|
||||||
name string,
|
name string,
|
||||||
primary key (channel, name)
|
primary key (channel, name)
|
||||||
);`); err != nil {
|
);`); err != nil {
|
||||||
log.Fatal().Err(err).Msgf("Initial DB migration create variables table")
|
log.Fatal().Err(err).Msgf("Initial DB migration create blacklist table")
|
||||||
|
}
|
||||||
|
if _, err := b.DB().Exec(`create table if not exists pluginWhitelist (
|
||||||
|
name string primary key
|
||||||
|
);`); err != nil {
|
||||||
|
log.Fatal().Err(err).Msgf("Initial DB migration create whitelist table")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -305,6 +315,24 @@ func (b *bot) RefreshPluginBlacklist() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshPluginWhitelist loads data for which plugins are enabled
|
||||||
|
func (b *bot) RefreshPluginWhitelist() error {
|
||||||
|
whitelistItems := []struct {
|
||||||
|
Name string
|
||||||
|
}{
|
||||||
|
{Name: "admin"}, // we must always ensure admin is on!
|
||||||
|
}
|
||||||
|
if err := b.DB().Select(&whitelistItems, `select name from pluginWhitelist`); err != nil {
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
|
}
|
||||||
|
b.pluginWhitelist = make(map[string]bool)
|
||||||
|
for _, i := range whitelistItems {
|
||||||
|
b.pluginWhitelist[i.Name] = true
|
||||||
|
}
|
||||||
|
log.Debug().Interface("whitelist", b.pluginWhitelist).Msgf("Refreshed plugin whitelist")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetPluginNames returns an ordered list of plugins loaded (used for blacklisting)
|
// GetPluginNames returns an ordered list of plugins loaded (used for blacklisting)
|
||||||
func (b *bot) GetPluginNames() []string {
|
func (b *bot) GetPluginNames() []string {
|
||||||
names := []string{}
|
names := []string{}
|
||||||
|
@ -314,10 +342,22 @@ func (b *bot) GetPluginNames() []string {
|
||||||
return names
|
return names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *bot) GetWhitelist() []string {
|
||||||
|
list := []string{}
|
||||||
|
for k := range b.pluginWhitelist {
|
||||||
|
list = append(list, k)
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
func (b *bot) onBlacklist(channel, plugin string) bool {
|
func (b *bot) onBlacklist(channel, plugin string) bool {
|
||||||
return b.pluginBlacklist[channel+plugin]
|
return b.pluginBlacklist[channel+plugin]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *bot) onWhitelist(plugin string) bool {
|
||||||
|
return b.pluginWhitelist[plugin]
|
||||||
|
}
|
||||||
|
|
||||||
func pluginNameStem(name string) string {
|
func pluginNameStem(name string) string {
|
||||||
return strings.Split(strings.TrimPrefix(name, "*"), ".")[0]
|
return strings.Split(strings.TrimPrefix(name, "*"), ".")[0]
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,9 +31,8 @@ func (b *bot) Receive(conn Connector, kind Kind, msg msg.Message, args ...interf
|
||||||
goto RET
|
goto RET
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msgf("checking blacklist %v", b.pluginBlacklist)
|
|
||||||
for _, name := range b.pluginOrdering {
|
for _, name := range b.pluginOrdering {
|
||||||
if b.onBlacklist(msg.Channel, pluginNameStem(name)) {
|
if b.onBlacklist(msg.Channel, pluginNameStem(name)) || !b.onWhitelist(pluginNameStem(name)) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if b.runCallback(conn, b.plugins[name], kind, msg, args...) {
|
if b.runCallback(conn, b.plugins[name], kind, msg, args...) {
|
||||||
|
|
|
@ -111,6 +111,12 @@ type Bot interface {
|
||||||
|
|
||||||
// RefreshPluginBlacklist reloads the list of plugins disabled per room from the DB
|
// RefreshPluginBlacklist reloads the list of plugins disabled per room from the DB
|
||||||
RefreshPluginBlacklist() error
|
RefreshPluginBlacklist() error
|
||||||
|
|
||||||
|
// RefreshPluginWhitelist reloads the list of plugins enabled from the DB
|
||||||
|
RefreshPluginWhitelist() error
|
||||||
|
|
||||||
|
// Get the contents of the white list
|
||||||
|
GetWhitelist() []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connector represents a server connection to a chat service
|
// Connector represents a server connection to a chat service
|
||||||
|
|
|
@ -115,3 +115,5 @@ func NewMockBot() *MockBot {
|
||||||
|
|
||||||
func (mb *MockBot) GetPluginNames() []string { return nil }
|
func (mb *MockBot) GetPluginNames() []string { return nil }
|
||||||
func (mb *MockBot) RefreshPluginBlacklist() error { return nil }
|
func (mb *MockBot) RefreshPluginBlacklist() error { return nil }
|
||||||
|
func (mb *MockBot) RefreshPluginWhitelist() error { return nil }
|
||||||
|
func (mb *MockBot) GetWhitelist() []string { return []string{} }
|
||||||
|
|
|
@ -218,18 +218,5 @@ func (d *Discord) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate
|
||||||
|
|
||||||
log.Debug().Interface("m", m).Interface("msg", msg).Msg("message received")
|
log.Debug().Interface("m", m).Interface("msg", msg).Msg("message received")
|
||||||
|
|
||||||
authorizedChannels := d.config.GetArray("channels", []string{})
|
|
||||||
|
|
||||||
if in(ch.Name, authorizedChannels) {
|
|
||||||
d.event(d, bot.Message, msg)
|
d.event(d, bot.Message, msg)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func in(s string, lst []string) bool {
|
|
||||||
for _, i := range lst {
|
|
||||||
if s == i {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
3
gok.sh
3
gok.sh
|
@ -15,8 +15,7 @@ gofmt -l $(find . -name '*.go') > $o 2>&1
|
||||||
test $(wc -l $o | awk '{ print $1 }') = "0" || fail
|
test $(wc -l $o | awk '{ print $1 }') = "0" || fail
|
||||||
|
|
||||||
echo govet
|
echo govet
|
||||||
#go vet ./... > $o 2>&1
|
go vet ./... > $o 2>&1
|
||||||
echo disabled
|
|
||||||
|
|
||||||
echo go test
|
echo go test
|
||||||
go test -test.timeout=60s ./... > $o 2>&1 || fail
|
go test -test.timeout=60s ./... > $o 2>&1 || fail
|
||||||
|
|
|
@ -55,6 +55,13 @@ var forbiddenKeys = map[string]bool{
|
||||||
var addBlacklist = regexp.MustCompile(`(?i)disable plugin (.*)`)
|
var addBlacklist = regexp.MustCompile(`(?i)disable plugin (.*)`)
|
||||||
var rmBlacklist = regexp.MustCompile(`(?i)enable plugin (.*)`)
|
var rmBlacklist = regexp.MustCompile(`(?i)enable plugin (.*)`)
|
||||||
|
|
||||||
|
var addWhitelist = regexp.MustCompile(`(?i)^whitelist plugin (.*)`)
|
||||||
|
var rmWhitelist = regexp.MustCompile(`(?i)^unwhitelist plugin (.*)`)
|
||||||
|
var allWhitelist = regexp.MustCompile(`(?i)^whitelist all`)
|
||||||
|
var allUnwhitelist = regexp.MustCompile(`(?i)^unwhitelist all`)
|
||||||
|
var getWhitelist = regexp.MustCompile(`(?i)^list whitelist`)
|
||||||
|
var getPlugins = regexp.MustCompile(`(?i)^list plugins`)
|
||||||
|
|
||||||
// Message responds to the bot hook on recieving messages.
|
// Message responds to the bot hook on recieving messages.
|
||||||
// This function returns true if the plugin responds in a meaningful way to the users message.
|
// This function returns true if the plugin responds in a meaningful way to the users message.
|
||||||
// Otherwise, the function returns false and the bot continues execution of other plugins.
|
// Otherwise, the function returns false and the bot continues execution of other plugins.
|
||||||
|
@ -131,6 +138,71 @@ func (p *AdminPlugin) message(conn bot.Connector, k bot.Kind, message msg.Messag
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if allWhitelist.MatchString(body) {
|
||||||
|
plugins := p.bot.GetPluginNames()
|
||||||
|
for _, plugin := range plugins {
|
||||||
|
if err := p.addWhitelist(plugin); err != nil {
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't whitelist that item: %s", err))
|
||||||
|
log.Error().Err(err).Msgf("error adding whitelist item")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, "Enabled all plugins")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if allUnwhitelist.MatchString(body) {
|
||||||
|
plugins := p.bot.GetPluginNames()
|
||||||
|
for _, plugin := range plugins {
|
||||||
|
if plugin == "admin" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := p.rmWhitelist(plugin); err != nil {
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't unwhitelist that item: %s", err))
|
||||||
|
log.Error().Err(err).Msgf("error removing whitelist item")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, "Disabled all plugins")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if addWhitelist.MatchString(body) {
|
||||||
|
submatches := addWhitelist.FindStringSubmatch(message.Body)
|
||||||
|
plugin := submatches[1]
|
||||||
|
if err := p.addWhitelist(plugin); err != nil {
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't whitelist that item: %s", err))
|
||||||
|
log.Error().Err(err).Msgf("error adding whitelist item")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("%s enabled. Use `!unwhitelist plugin %s` to disable it.", plugin, plugin))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if rmWhitelist.MatchString(body) {
|
||||||
|
submatches := rmWhitelist.FindStringSubmatch(message.Body)
|
||||||
|
plugin := submatches[1]
|
||||||
|
if err := p.rmWhitelist(plugin); err != nil {
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't unwhitelist that item: %s", err))
|
||||||
|
log.Error().Err(err).Msgf("error removing whitelist item")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("%s disabled. Use `!whitelist plugin %s` to enable it.", plugin, plugin))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if getWhitelist.MatchString(body) {
|
||||||
|
list := p.bot.GetWhitelist()
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("Whitelist: %v", list))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if getPlugins.MatchString(body) {
|
||||||
|
plugins := p.bot.GetPluginNames()
|
||||||
|
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("Plugins: %v", plugins))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
if strings.ToLower(body) == "password" {
|
if strings.ToLower(body) == "password" {
|
||||||
p.bot.Send(conn, bot.Message, message.Channel, p.bot.GetPassword())
|
p.bot.Send(conn, bot.Message, message.Channel, p.bot.GetPassword())
|
||||||
return true
|
return true
|
||||||
|
@ -283,25 +355,43 @@ func (p *AdminPlugin) handleWebAPI(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "%s", j)
|
fmt.Fprintf(w, "%s", j)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *AdminPlugin) addWhitelist(plugin string) error {
|
||||||
|
return p.modList(`insert or replace into pluginWhitelist values (?)`, "", plugin)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AdminPlugin) rmWhitelist(plugin string) error {
|
||||||
|
if plugin == "admin" {
|
||||||
|
return fmt.Errorf("you cannot disable the admin plugin")
|
||||||
|
}
|
||||||
|
return p.modList(`delete from pluginWhitelist where name=?`, "", plugin)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *AdminPlugin) addBlacklist(channel, plugin string) error {
|
func (p *AdminPlugin) addBlacklist(channel, plugin string) error {
|
||||||
if plugin == "admin" {
|
if plugin == "admin" {
|
||||||
return fmt.Errorf("you cannot disable the admin plugin")
|
return fmt.Errorf("you cannot disable the admin plugin")
|
||||||
}
|
}
|
||||||
return p.modBlacklist(`insert or replace into pluginBlacklist values (?, ?)`, channel, plugin)
|
return p.modList(`insert or replace into pluginBlacklist values (?, ?)`, channel, plugin)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *AdminPlugin) rmBlacklist(channel, plugin string) error {
|
func (p *AdminPlugin) rmBlacklist(channel, plugin string) error {
|
||||||
return p.modBlacklist(`delete from pluginBlacklist where channel=? and name=?`, channel, plugin)
|
return p.modList(`delete from pluginBlacklist where channel=? and name=?`, channel, plugin)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *AdminPlugin) modBlacklist(query, channel, plugin string) error {
|
func (p *AdminPlugin) modList(query, channel, plugin string) error {
|
||||||
|
if channel == "" && plugin != "" {
|
||||||
|
channel = plugin // hack
|
||||||
|
}
|
||||||
plugins := p.bot.GetPluginNames()
|
plugins := p.bot.GetPluginNames()
|
||||||
for _, pp := range plugins {
|
for _, pp := range plugins {
|
||||||
if pp == plugin {
|
if pp == plugin {
|
||||||
if _, err := p.db.Exec(query, channel, plugin); err != nil {
|
if _, err := p.db.Exec(query, channel, plugin); err != nil {
|
||||||
return fmt.Errorf("%w", err)
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
err := p.bot.RefreshPluginBlacklist()
|
err := p.bot.RefreshPluginWhitelist()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w", err)
|
||||||
|
}
|
||||||
|
err = p.bot.RefreshPluginBlacklist()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w", err)
|
return fmt.Errorf("%w", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue