From b59cd6441f8edc4c79c7f0d4a0715637f69e4845 Mon Sep 17 00:00:00 2001 From: skiesel Date: Wed, 11 May 2016 13:07:16 -0400 Subject: [PATCH] batch learning for the babbler --- plugins/babbler/babbler.go | 29 ++++++++++++++++++++++++++--- plugins/babbler/babbler_test.go | 16 ++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 3405a91..222df25 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -3,7 +3,6 @@ package babbler import ( - // "database/sql" "fmt" "math/rand" "strings" @@ -87,8 +86,8 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { } p.Bot.SendMessage(message.Channel, saying) return true - } else if len(tokens) == 4 && strings.Contains(lowercase, "initialize babbler for ") { - who := tokens[len(tokens)-1] + } else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 { + who := tokens[3] if _, ok := p.babblers[who]; !ok { babbler, err := getMarkovChain(p.db, who) if err == nil { @@ -99,6 +98,30 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { p.Bot.SendMessage(message.Channel, "Okay.") return true } + } else if strings.Index(lowercase, "batch learn for ") == 0 { + who := tokens[3] + if _, ok := p.babblers[who]; !ok { + p.babblers[who] = newBabbler() + } + + body := strings.Join(tokens[4:], " ") + body = strings.ToLower(body) + + for _, a := range strings.Split(body, ".") { + for _, b := range strings.Split(a, "!") { + for _, c := range strings.Split(b, "?") { + for _, d := range strings.Split(c, "\n") { + trimmed := strings.TrimSpace(d) + if trimmed != "" { + addToMarkovChain(p.babblers[who], trimmed) + } + } + } + } + } + + p.Bot.SendMessage(message.Channel, "Phew that was tiring.") + return true } else { addToMarkovChain(p.babblers[message.User.Name], lowercase) } diff --git a/plugins/babbler/babbler_test.go b/plugins/babbler/babbler_test.go index aabbd7d..62b517b 100644 --- a/plugins/babbler/babbler_test.go +++ b/plugins/babbler/babbler_test.go @@ -45,6 +45,22 @@ func TestBabbler(t *testing.T) { assert.Contains(t, mb.Messages[0], "message") } +func TestBabblerBatch(t *testing.T) { + mb := bot.NewMockBot() + c := New(mb) + c.config.Babbler.DefaultUsers = []string{"seabass"} + assert.NotNil(t, c) + seabass := makeMessage("batch learn for seabass This is a message! This is another message. This is not a long message? This is not a message! This is not another message. This is a long message?") + res := c.Message(seabass) + assert.Len(t, c.babblers, 2) + assert.Len(t, mb.Messages, 1) + res = c.Message(makeMessage("!seabass says")) + assert.Len(t, mb.Messages, 2) + assert.True(t, res) + assert.Contains(t, mb.Messages[1], "this is") + assert.Contains(t, mb.Messages[1], "message") +} + func TestHelp(t *testing.T) { mb := bot.NewMockBot() c := New(mb)