diff --git a/bot/bot.go b/bot/bot.go index 5eb0181..5670967 100644 --- a/bot/bot.go +++ b/bot/bot.go @@ -29,6 +29,9 @@ type bot struct { // channel -> plugin pluginBlacklist map[string]bool + // plugin, this is bot-wide + pluginWhitelist map[string]bool + // Users holds information about all of our friends users []user.User // Represents the bot @@ -84,6 +87,7 @@ func New(config *config.Config, connector Connector) Bot { plugins: make(map[string]Plugin), pluginOrdering: make([]string, 0), pluginBlacklist: make(map[string]bool), + pluginWhitelist: make(map[string]bool), conn: connector, users: users, me: users[0], @@ -97,6 +101,7 @@ func New(config *config.Config, connector Connector) Bot { bot.migrateDB() bot.RefreshPluginBlacklist() + bot.RefreshPluginWhitelist() http.HandleFunc("/", bot.serveRoot) @@ -142,7 +147,12 @@ func (b *bot) migrateDB() { name string, primary key (channel, name) );`); err != nil { - log.Fatal().Err(err).Msgf("Initial DB migration create variables table") + log.Fatal().Err(err).Msgf("Initial DB migration create blacklist table") + } + if _, err := b.DB().Exec(`create table if not exists pluginWhitelist ( + name string primary key + );`); err != nil { + log.Fatal().Err(err).Msgf("Initial DB migration create whitelist table") } } @@ -305,6 +315,24 @@ func (b *bot) RefreshPluginBlacklist() error { return nil } +// RefreshPluginWhitelist loads data for which plugins are enabled +func (b *bot) RefreshPluginWhitelist() error { + whitelistItems := []struct { + Name string + }{ + {Name: "admin"}, // we must always ensure admin is on! + } + if err := b.DB().Select(&whitelistItems, `select name from pluginWhitelist`); err != nil { + return fmt.Errorf("%w", err) + } + b.pluginWhitelist = make(map[string]bool) + for _, i := range whitelistItems { + b.pluginWhitelist[i.Name] = true + } + log.Debug().Interface("whitelist", b.pluginWhitelist).Msgf("Refreshed plugin whitelist") + return nil +} + // GetPluginNames returns an ordered list of plugins loaded (used for blacklisting) func (b *bot) GetPluginNames() []string { names := []string{} @@ -314,10 +342,22 @@ func (b *bot) GetPluginNames() []string { return names } +func (b *bot) GetWhitelist() []string { + list := []string{} + for k := range b.pluginWhitelist { + list = append(list, k) + } + return list +} + func (b *bot) onBlacklist(channel, plugin string) bool { return b.pluginBlacklist[channel+plugin] } +func (b *bot) onWhitelist(plugin string) bool { + return b.pluginWhitelist[plugin] +} + func pluginNameStem(name string) string { return strings.Split(strings.TrimPrefix(name, "*"), ".")[0] } diff --git a/bot/handlers.go b/bot/handlers.go index 0508f97..9e94ba2 100644 --- a/bot/handlers.go +++ b/bot/handlers.go @@ -31,9 +31,8 @@ func (b *bot) Receive(conn Connector, kind Kind, msg msg.Message, args ...interf goto RET } - log.Debug().Msgf("checking blacklist %v", b.pluginBlacklist) for _, name := range b.pluginOrdering { - if b.onBlacklist(msg.Channel, pluginNameStem(name)) { + if b.onBlacklist(msg.Channel, pluginNameStem(name)) || !b.onWhitelist(pluginNameStem(name)) { continue } if b.runCallback(conn, b.plugins[name], kind, msg, args...) { diff --git a/bot/interfaces.go b/bot/interfaces.go index af93915..5ca34ff 100644 --- a/bot/interfaces.go +++ b/bot/interfaces.go @@ -111,6 +111,12 @@ type Bot interface { // RefreshPluginBlacklist reloads the list of plugins disabled per room from the DB RefreshPluginBlacklist() error + + // RefreshPluginWhitelist reloads the list of plugins enabled from the DB + RefreshPluginWhitelist() error + + // Get the contents of the white list + GetWhitelist() []string } // Connector represents a server connection to a chat service diff --git a/bot/mock.go b/bot/mock.go index 7179482..679a687 100644 --- a/bot/mock.go +++ b/bot/mock.go @@ -115,3 +115,5 @@ func NewMockBot() *MockBot { func (mb *MockBot) GetPluginNames() []string { return nil } func (mb *MockBot) RefreshPluginBlacklist() error { return nil } +func (mb *MockBot) RefreshPluginWhitelist() error { return nil } +func (mb *MockBot) GetWhitelist() []string { return []string{} } diff --git a/connectors/discord/discord.go b/connectors/discord/discord.go index 19e8d55..fc5d141 100644 --- a/connectors/discord/discord.go +++ b/connectors/discord/discord.go @@ -218,18 +218,5 @@ func (d *Discord) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate log.Debug().Interface("m", m).Interface("msg", msg).Msg("message received") - authorizedChannels := d.config.GetArray("channels", []string{}) - - if in(ch.Name, authorizedChannels) { - d.event(d, bot.Message, msg) - } -} - -func in(s string, lst []string) bool { - for _, i := range lst { - if s == i { - return true - } - } - return false + d.event(d, bot.Message, msg) } diff --git a/gok.sh b/gok.sh index ed314a3..84cfe04 100755 --- a/gok.sh +++ b/gok.sh @@ -15,8 +15,7 @@ gofmt -l $(find . -name '*.go') > $o 2>&1 test $(wc -l $o | awk '{ print $1 }') = "0" || fail echo govet -#go vet ./... > $o 2>&1 -echo disabled +go vet ./... > $o 2>&1 echo go test go test -test.timeout=60s ./... > $o 2>&1 || fail diff --git a/plugins/admin/admin.go b/plugins/admin/admin.go index 7c304e0..20369bc 100644 --- a/plugins/admin/admin.go +++ b/plugins/admin/admin.go @@ -55,6 +55,13 @@ var forbiddenKeys = map[string]bool{ var addBlacklist = regexp.MustCompile(`(?i)disable plugin (.*)`) var rmBlacklist = regexp.MustCompile(`(?i)enable plugin (.*)`) +var addWhitelist = regexp.MustCompile(`(?i)^whitelist plugin (.*)`) +var rmWhitelist = regexp.MustCompile(`(?i)^unwhitelist plugin (.*)`) +var allWhitelist = regexp.MustCompile(`(?i)^whitelist all`) +var allUnwhitelist = regexp.MustCompile(`(?i)^unwhitelist all`) +var getWhitelist = regexp.MustCompile(`(?i)^list whitelist`) +var getPlugins = regexp.MustCompile(`(?i)^list plugins`) + // Message responds to the bot hook on recieving messages. // This function returns true if the plugin responds in a meaningful way to the users message. // Otherwise, the function returns false and the bot continues execution of other plugins. @@ -131,6 +138,71 @@ func (p *AdminPlugin) message(conn bot.Connector, k bot.Kind, message msg.Messag return true } + if allWhitelist.MatchString(body) { + plugins := p.bot.GetPluginNames() + for _, plugin := range plugins { + if err := p.addWhitelist(plugin); err != nil { + p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't whitelist that item: %s", err)) + log.Error().Err(err).Msgf("error adding whitelist item") + return true + } + } + p.bot.Send(conn, bot.Message, message.Channel, "Enabled all plugins") + return true + } + + if allUnwhitelist.MatchString(body) { + plugins := p.bot.GetPluginNames() + for _, plugin := range plugins { + if plugin == "admin" { + continue + } + if err := p.rmWhitelist(plugin); err != nil { + p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't unwhitelist that item: %s", err)) + log.Error().Err(err).Msgf("error removing whitelist item") + return true + } + } + p.bot.Send(conn, bot.Message, message.Channel, "Disabled all plugins") + return true + } + + if addWhitelist.MatchString(body) { + submatches := addWhitelist.FindStringSubmatch(message.Body) + plugin := submatches[1] + if err := p.addWhitelist(plugin); err != nil { + p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't whitelist that item: %s", err)) + log.Error().Err(err).Msgf("error adding whitelist item") + return true + } + p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("%s enabled. Use `!unwhitelist plugin %s` to disable it.", plugin, plugin)) + return true + } + + if rmWhitelist.MatchString(body) { + submatches := rmWhitelist.FindStringSubmatch(message.Body) + plugin := submatches[1] + if err := p.rmWhitelist(plugin); err != nil { + p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't unwhitelist that item: %s", err)) + log.Error().Err(err).Msgf("error removing whitelist item") + return true + } + p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("%s disabled. Use `!whitelist plugin %s` to enable it.", plugin, plugin)) + return true + } + + if getWhitelist.MatchString(body) { + list := p.bot.GetWhitelist() + p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("Whitelist: %v", list)) + return true + } + + if getPlugins.MatchString(body) { + plugins := p.bot.GetPluginNames() + p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("Plugins: %v", plugins)) + return true + } + if strings.ToLower(body) == "password" { p.bot.Send(conn, bot.Message, message.Channel, p.bot.GetPassword()) return true @@ -283,25 +355,43 @@ func (p *AdminPlugin) handleWebAPI(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "%s", j) } +func (p *AdminPlugin) addWhitelist(plugin string) error { + return p.modList(`insert or replace into pluginWhitelist values (?)`, "", plugin) +} + +func (p *AdminPlugin) rmWhitelist(plugin string) error { + if plugin == "admin" { + return fmt.Errorf("you cannot disable the admin plugin") + } + return p.modList(`delete from pluginWhitelist where name=?`, "", plugin) +} + func (p *AdminPlugin) addBlacklist(channel, plugin string) error { if plugin == "admin" { return fmt.Errorf("you cannot disable the admin plugin") } - return p.modBlacklist(`insert or replace into pluginBlacklist values (?, ?)`, channel, plugin) + return p.modList(`insert or replace into pluginBlacklist values (?, ?)`, channel, plugin) } func (p *AdminPlugin) rmBlacklist(channel, plugin string) error { - return p.modBlacklist(`delete from pluginBlacklist where channel=? and name=?`, channel, plugin) + return p.modList(`delete from pluginBlacklist where channel=? and name=?`, channel, plugin) } -func (p *AdminPlugin) modBlacklist(query, channel, plugin string) error { +func (p *AdminPlugin) modList(query, channel, plugin string) error { + if channel == "" && plugin != "" { + channel = plugin // hack + } plugins := p.bot.GetPluginNames() for _, pp := range plugins { if pp == plugin { if _, err := p.db.Exec(query, channel, plugin); err != nil { return fmt.Errorf("%w", err) } - err := p.bot.RefreshPluginBlacklist() + err := p.bot.RefreshPluginWhitelist() + if err != nil { + return fmt.Errorf("%w", err) + } + err = p.bot.RefreshPluginBlacklist() if err != nil { return fmt.Errorf("%w", err) }