Merge pull request #298 from velour/whitelist

Whitelist
This commit is contained in:
Chris Sexton 2020-10-09 12:07:32 -04:00 committed by GitHub
commit 2ed92927a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 146 additions and 23 deletions

View File

@ -29,6 +29,9 @@ type bot struct {
// channel -> plugin // channel -> plugin
pluginBlacklist map[string]bool pluginBlacklist map[string]bool
// plugin, this is bot-wide
pluginWhitelist map[string]bool
// Users holds information about all of our friends // Users holds information about all of our friends
users []user.User users []user.User
// Represents the bot // Represents the bot
@ -84,6 +87,7 @@ func New(config *config.Config, connector Connector) Bot {
plugins: make(map[string]Plugin), plugins: make(map[string]Plugin),
pluginOrdering: make([]string, 0), pluginOrdering: make([]string, 0),
pluginBlacklist: make(map[string]bool), pluginBlacklist: make(map[string]bool),
pluginWhitelist: make(map[string]bool),
conn: connector, conn: connector,
users: users, users: users,
me: users[0], me: users[0],
@ -97,6 +101,7 @@ func New(config *config.Config, connector Connector) Bot {
bot.migrateDB() bot.migrateDB()
bot.RefreshPluginBlacklist() bot.RefreshPluginBlacklist()
bot.RefreshPluginWhitelist()
http.HandleFunc("/", bot.serveRoot) http.HandleFunc("/", bot.serveRoot)
@ -142,7 +147,12 @@ func (b *bot) migrateDB() {
name string, name string,
primary key (channel, name) primary key (channel, name)
);`); err != nil { );`); err != nil {
log.Fatal().Err(err).Msgf("Initial DB migration create variables table") log.Fatal().Err(err).Msgf("Initial DB migration create blacklist table")
}
if _, err := b.DB().Exec(`create table if not exists pluginWhitelist (
name string primary key
);`); err != nil {
log.Fatal().Err(err).Msgf("Initial DB migration create whitelist table")
} }
} }
@ -305,6 +315,24 @@ func (b *bot) RefreshPluginBlacklist() error {
return nil return nil
} }
// RefreshPluginWhitelist loads data for which plugins are enabled
func (b *bot) RefreshPluginWhitelist() error {
whitelistItems := []struct {
Name string
}{
{Name: "admin"}, // we must always ensure admin is on!
}
if err := b.DB().Select(&whitelistItems, `select name from pluginWhitelist`); err != nil {
return fmt.Errorf("%w", err)
}
b.pluginWhitelist = make(map[string]bool)
for _, i := range whitelistItems {
b.pluginWhitelist[i.Name] = true
}
log.Debug().Interface("whitelist", b.pluginWhitelist).Msgf("Refreshed plugin whitelist")
return nil
}
// GetPluginNames returns an ordered list of plugins loaded (used for blacklisting) // GetPluginNames returns an ordered list of plugins loaded (used for blacklisting)
func (b *bot) GetPluginNames() []string { func (b *bot) GetPluginNames() []string {
names := []string{} names := []string{}
@ -314,10 +342,22 @@ func (b *bot) GetPluginNames() []string {
return names return names
} }
func (b *bot) GetWhitelist() []string {
list := []string{}
for k := range b.pluginWhitelist {
list = append(list, k)
}
return list
}
func (b *bot) onBlacklist(channel, plugin string) bool { func (b *bot) onBlacklist(channel, plugin string) bool {
return b.pluginBlacklist[channel+plugin] return b.pluginBlacklist[channel+plugin]
} }
func (b *bot) onWhitelist(plugin string) bool {
return b.pluginWhitelist[plugin]
}
func pluginNameStem(name string) string { func pluginNameStem(name string) string {
return strings.Split(strings.TrimPrefix(name, "*"), ".")[0] return strings.Split(strings.TrimPrefix(name, "*"), ".")[0]
} }

View File

@ -31,9 +31,8 @@ func (b *bot) Receive(conn Connector, kind Kind, msg msg.Message, args ...interf
goto RET goto RET
} }
log.Debug().Msgf("checking blacklist %v", b.pluginBlacklist)
for _, name := range b.pluginOrdering { for _, name := range b.pluginOrdering {
if b.onBlacklist(msg.Channel, pluginNameStem(name)) { if b.onBlacklist(msg.Channel, pluginNameStem(name)) || !b.onWhitelist(pluginNameStem(name)) {
continue continue
} }
if b.runCallback(conn, b.plugins[name], kind, msg, args...) { if b.runCallback(conn, b.plugins[name], kind, msg, args...) {

View File

@ -111,6 +111,12 @@ type Bot interface {
// RefreshPluginBlacklist reloads the list of plugins disabled per room from the DB // RefreshPluginBlacklist reloads the list of plugins disabled per room from the DB
RefreshPluginBlacklist() error RefreshPluginBlacklist() error
// RefreshPluginWhitelist reloads the list of plugins enabled from the DB
RefreshPluginWhitelist() error
// Get the contents of the white list
GetWhitelist() []string
} }
// Connector represents a server connection to a chat service // Connector represents a server connection to a chat service

View File

@ -115,3 +115,5 @@ func NewMockBot() *MockBot {
func (mb *MockBot) GetPluginNames() []string { return nil } func (mb *MockBot) GetPluginNames() []string { return nil }
func (mb *MockBot) RefreshPluginBlacklist() error { return nil } func (mb *MockBot) RefreshPluginBlacklist() error { return nil }
func (mb *MockBot) RefreshPluginWhitelist() error { return nil }
func (mb *MockBot) GetWhitelist() []string { return []string{} }

View File

@ -218,18 +218,5 @@ func (d *Discord) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate
log.Debug().Interface("m", m).Interface("msg", msg).Msg("message received") log.Debug().Interface("m", m).Interface("msg", msg).Msg("message received")
authorizedChannels := d.config.GetArray("channels", []string{}) d.event(d, bot.Message, msg)
if in(ch.Name, authorizedChannels) {
d.event(d, bot.Message, msg)
}
}
func in(s string, lst []string) bool {
for _, i := range lst {
if s == i {
return true
}
}
return false
} }

3
gok.sh
View File

@ -15,8 +15,7 @@ gofmt -l $(find . -name '*.go') > $o 2>&1
test $(wc -l $o | awk '{ print $1 }') = "0" || fail test $(wc -l $o | awk '{ print $1 }') = "0" || fail
echo govet echo govet
#go vet ./... > $o 2>&1 go vet ./... > $o 2>&1
echo disabled
echo go test echo go test
go test -test.timeout=60s ./... > $o 2>&1 || fail go test -test.timeout=60s ./... > $o 2>&1 || fail

View File

@ -55,6 +55,13 @@ var forbiddenKeys = map[string]bool{
var addBlacklist = regexp.MustCompile(`(?i)disable plugin (.*)`) var addBlacklist = regexp.MustCompile(`(?i)disable plugin (.*)`)
var rmBlacklist = regexp.MustCompile(`(?i)enable plugin (.*)`) var rmBlacklist = regexp.MustCompile(`(?i)enable plugin (.*)`)
var addWhitelist = regexp.MustCompile(`(?i)^whitelist plugin (.*)`)
var rmWhitelist = regexp.MustCompile(`(?i)^unwhitelist plugin (.*)`)
var allWhitelist = regexp.MustCompile(`(?i)^whitelist all`)
var allUnwhitelist = regexp.MustCompile(`(?i)^unwhitelist all`)
var getWhitelist = regexp.MustCompile(`(?i)^list whitelist`)
var getPlugins = regexp.MustCompile(`(?i)^list plugins`)
// Message responds to the bot hook on recieving messages. // Message responds to the bot hook on recieving messages.
// This function returns true if the plugin responds in a meaningful way to the users message. // This function returns true if the plugin responds in a meaningful way to the users message.
// Otherwise, the function returns false and the bot continues execution of other plugins. // Otherwise, the function returns false and the bot continues execution of other plugins.
@ -131,6 +138,71 @@ func (p *AdminPlugin) message(conn bot.Connector, k bot.Kind, message msg.Messag
return true return true
} }
if allWhitelist.MatchString(body) {
plugins := p.bot.GetPluginNames()
for _, plugin := range plugins {
if err := p.addWhitelist(plugin); err != nil {
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't whitelist that item: %s", err))
log.Error().Err(err).Msgf("error adding whitelist item")
return true
}
}
p.bot.Send(conn, bot.Message, message.Channel, "Enabled all plugins")
return true
}
if allUnwhitelist.MatchString(body) {
plugins := p.bot.GetPluginNames()
for _, plugin := range plugins {
if plugin == "admin" {
continue
}
if err := p.rmWhitelist(plugin); err != nil {
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't unwhitelist that item: %s", err))
log.Error().Err(err).Msgf("error removing whitelist item")
return true
}
}
p.bot.Send(conn, bot.Message, message.Channel, "Disabled all plugins")
return true
}
if addWhitelist.MatchString(body) {
submatches := addWhitelist.FindStringSubmatch(message.Body)
plugin := submatches[1]
if err := p.addWhitelist(plugin); err != nil {
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't whitelist that item: %s", err))
log.Error().Err(err).Msgf("error adding whitelist item")
return true
}
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("%s enabled. Use `!unwhitelist plugin %s` to disable it.", plugin, plugin))
return true
}
if rmWhitelist.MatchString(body) {
submatches := rmWhitelist.FindStringSubmatch(message.Body)
plugin := submatches[1]
if err := p.rmWhitelist(plugin); err != nil {
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("I couldn't unwhitelist that item: %s", err))
log.Error().Err(err).Msgf("error removing whitelist item")
return true
}
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("%s disabled. Use `!whitelist plugin %s` to enable it.", plugin, plugin))
return true
}
if getWhitelist.MatchString(body) {
list := p.bot.GetWhitelist()
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("Whitelist: %v", list))
return true
}
if getPlugins.MatchString(body) {
plugins := p.bot.GetPluginNames()
p.bot.Send(conn, bot.Message, message.Channel, fmt.Sprintf("Plugins: %v", plugins))
return true
}
if strings.ToLower(body) == "password" { if strings.ToLower(body) == "password" {
p.bot.Send(conn, bot.Message, message.Channel, p.bot.GetPassword()) p.bot.Send(conn, bot.Message, message.Channel, p.bot.GetPassword())
return true return true
@ -283,25 +355,43 @@ func (p *AdminPlugin) handleWebAPI(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s", j) fmt.Fprintf(w, "%s", j)
} }
func (p *AdminPlugin) addWhitelist(plugin string) error {
return p.modList(`insert or replace into pluginWhitelist values (?)`, "", plugin)
}
func (p *AdminPlugin) rmWhitelist(plugin string) error {
if plugin == "admin" {
return fmt.Errorf("you cannot disable the admin plugin")
}
return p.modList(`delete from pluginWhitelist where name=?`, "", plugin)
}
func (p *AdminPlugin) addBlacklist(channel, plugin string) error { func (p *AdminPlugin) addBlacklist(channel, plugin string) error {
if plugin == "admin" { if plugin == "admin" {
return fmt.Errorf("you cannot disable the admin plugin") return fmt.Errorf("you cannot disable the admin plugin")
} }
return p.modBlacklist(`insert or replace into pluginBlacklist values (?, ?)`, channel, plugin) return p.modList(`insert or replace into pluginBlacklist values (?, ?)`, channel, plugin)
} }
func (p *AdminPlugin) rmBlacklist(channel, plugin string) error { func (p *AdminPlugin) rmBlacklist(channel, plugin string) error {
return p.modBlacklist(`delete from pluginBlacklist where channel=? and name=?`, channel, plugin) return p.modList(`delete from pluginBlacklist where channel=? and name=?`, channel, plugin)
} }
func (p *AdminPlugin) modBlacklist(query, channel, plugin string) error { func (p *AdminPlugin) modList(query, channel, plugin string) error {
if channel == "" && plugin != "" {
channel = plugin // hack
}
plugins := p.bot.GetPluginNames() plugins := p.bot.GetPluginNames()
for _, pp := range plugins { for _, pp := range plugins {
if pp == plugin { if pp == plugin {
if _, err := p.db.Exec(query, channel, plugin); err != nil { if _, err := p.db.Exec(query, channel, plugin); err != nil {
return fmt.Errorf("%w", err) return fmt.Errorf("%w", err)
} }
err := p.bot.RefreshPluginBlacklist() err := p.bot.RefreshPluginWhitelist()
if err != nil {
return fmt.Errorf("%w", err)
}
err = p.bot.RefreshPluginBlacklist()
if err != nil { if err != nil {
return fmt.Errorf("%w", err) return fmt.Errorf("%w", err)
} }