diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 685bd04..222df25 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -3,7 +3,6 @@ package babbler import ( - // "database/sql" "fmt" "math/rand" "strings" @@ -80,10 +79,15 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { lowercase := strings.ToLower(message.Body) tokens := strings.Fields(lowercase) - addToMarkovChain(p.babblers[message.User.Name], lowercase) - - if len(tokens) == 4 && strings.Contains(lowercase, "initialize babbler for ") { - who := tokens[len(tokens)-1] + if len(tokens) == 2 && tokens[1] == "says" { + saying := p.babble(tokens[0]) + if saying == "" { + p.Bot.SendMessage(message.Channel, "Ze ain't said nothin'") + } + p.Bot.SendMessage(message.Channel, saying) + return true + } else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 { + who := tokens[3] if _, ok := p.babblers[who]; !ok { babbler, err := getMarkovChain(p.db, who) if err == nil { @@ -94,16 +98,35 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { p.Bot.SendMessage(message.Channel, "Okay.") return true } + } else if strings.Index(lowercase, "batch learn for ") == 0 { + who := tokens[3] + if _, ok := p.babblers[who]; !ok { + p.babblers[who] = newBabbler() + } + + body := strings.Join(tokens[4:], " ") + body = strings.ToLower(body) + + for _, a := range strings.Split(body, ".") { + for _, b := range strings.Split(a, "!") { + for _, c := range strings.Split(b, "?") { + for _, d := range strings.Split(c, "\n") { + trimmed := strings.TrimSpace(d) + if trimmed != "" { + addToMarkovChain(p.babblers[who], trimmed) + } + } + } + } + } + + p.Bot.SendMessage(message.Channel, "Phew that was tiring.") + return true + } else { + addToMarkovChain(p.babblers[message.User.Name], lowercase) } - if len(tokens) == 2 && tokens[1] == "says" { - saying := p.babble(tokens[0]) - if saying == "" { - p.Bot.SendMessage(message.Channel, "Ze ain't said nothin'") - } - p.Bot.SendMessage(message.Channel, saying) - return true - } + return false } diff --git a/plugins/babbler/babbler_test.go b/plugins/babbler/babbler_test.go index aabbd7d..62b517b 100644 --- a/plugins/babbler/babbler_test.go +++ b/plugins/babbler/babbler_test.go @@ -45,6 +45,22 @@ func TestBabbler(t *testing.T) { assert.Contains(t, mb.Messages[0], "message") } +func TestBabblerBatch(t *testing.T) { + mb := bot.NewMockBot() + c := New(mb) + c.config.Babbler.DefaultUsers = []string{"seabass"} + assert.NotNil(t, c) + seabass := makeMessage("batch learn for seabass This is a message! This is another message. This is not a long message? This is not a message! This is not another message. This is a long message?") + res := c.Message(seabass) + assert.Len(t, c.babblers, 2) + assert.Len(t, mb.Messages, 1) + res = c.Message(makeMessage("!seabass says")) + assert.Len(t, mb.Messages, 2) + assert.True(t, res) + assert.Contains(t, mb.Messages[1], "this is") + assert.Contains(t, mb.Messages[1], "message") +} + func TestHelp(t *testing.T) { mb := bot.NewMockBot() c := New(mb)