From aaba7c5bff504e85d49323a32e9d53a7e4297bc8 Mon Sep 17 00:00:00 2001 From: skkiesel Date: Mon, 5 Jun 2017 19:53:19 -0400 Subject: [PATCH] allow babbles to formed with suffixes (closes #55) --- plugins/babbler/babbler.go | 125 ++++++++++++++++++++++++++++++++ plugins/babbler/babbler_test.go | 38 ++++++++++ plugins/babbler/commands.go | 25 +++++++ 3 files changed, 188 insertions(+) diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 357796a..eb33bf1 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -112,6 +112,8 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { if numTokens >= 2 && tokens[1] == "says" { saidWhat, saidSomething = p.getBabble(tokens) + } else if numTokens > 2 && tokens[1] == "syas" { + saidWhat, saidSomething = p.getBabbleWithSuffix(tokens) } else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 { saidWhat, saidSomething = p.initializeBabbler(tokens) } else if strings.Index(lowercase, "batch learn for ") == 0 { @@ -493,6 +495,64 @@ func (p *BabblerPlugin) getWeightedNextWord(fromNode *BabblerNode) (*BabblerNode return nil, nil, errors.New("failed to find weighted next word") } +func (p *BabblerPlugin) getWeightedPreviousWord(toNode *BabblerNode) (*BabblerNode, *BabblerWord, bool, error) { + rows, err := p.db.Queryx(`select * from babblerArcs where toNodeId = ?;`, toNode.NodeId) + if err != nil { + log.Print(err) + return nil, nil, false, err + } + defer rows.Close() + + arcs := []*BabblerArc{} + total := int64(0) + for rows.Next() { + var arc BabblerArc + err = rows.StructScan(&arc) + if err != nil { + log.Print(err) + return nil, nil, false, err + } + arcs = append(arcs, &arc) + total += arc.Frequency + } + + if len(arcs) == 0 { + return nil, nil, true, nil + } + + which := rand.Int63n(total + toNode.RootFrequency) + + //terminate the babble + if which >= total { + return nil, nil, true, nil + } + + total = 0 + for _, arc := range arcs { + + total += arc.Frequency + + if total >= which { + var node BabblerNode + err := p.db.QueryRowx(`select * from babblerNodes where id = ? LIMIT 1;`, arc.FromNodeId).StructScan(&node) + if err != nil { + log.Print(err) + return nil, nil, false, err + } + + var w BabblerWord + err = p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, node.WordId).StructScan(&w) + if err != nil { + log.Print(err) + return nil, nil, false, err + } + return &node, &w, false, nil + } + } + log.Fatalf("shouldn't happen") + return nil, nil, false, errors.New("failed to find weighted previous word") +} + func (p *BabblerPlugin) babble(who string) (string, error) { return p.babbleSeed(who, []string{}) } @@ -659,3 +719,68 @@ func (p *BabblerPlugin) mergeBabblers(intoBabbler, otherBabbler *Babbler, intoNa return err } + + +func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (string, error) { + babbler, err := p.getBabbler(babblerName) + if err != nil { + log.Print(err) + return "", nil + } + + var curNode *BabblerNode + var firstNode *BabblerNode + + //this is weird, but let's first make sure that the suffix exists "in order" + //because it's already writen and we need to do it anyway... + curNode, err = p.getBabblerNode(babbler, seed[0]) + if err != nil { + log.Print(err) + return "", err + } + firstNode = curNode + for i := 1; i < len(seed); i++ { + nextNode, err := p.getBabblerNode(babbler, seed[i]) + if err != nil { + log.Print(err) + return "", err + } + _, err = p.getBabblerArc(curNode, nextNode) + if err != nil { + log.Print(err) + return "", err + } + curNode = nextNode + } + + words := []string{} + var curWord *BabblerWord + var shouldTerminate bool + curNode = firstNode + for { + curNode, curWord, shouldTerminate, err = p.getWeightedPreviousWord(curNode) + if err != nil { + log.Print(err) + return "", err + } + + if shouldTerminate { + break + } + + words = append(words, curWord.Word) + + if len(words) >= 250 { + break + } + } + + for i := 0; i < len(words) / 2; i++ { + index := len(words)-(i+1) + words[i], words[index] = words[index], words[i] + } + + words = append(words, seed...) + + return strings.TrimSpace(strings.Join(words, " ")), nil +} diff --git a/plugins/babbler/babbler_test.go b/plugins/babbler/babbler_test.go index a568b8b..444f739 100644 --- a/plugins/babbler/babbler_test.go +++ b/plugins/babbler/babbler_test.go @@ -157,6 +157,44 @@ func TestBabblerBadSeed2(t *testing.T) { assert.Contains(t, mb.Messages[0], "seabass never said 'this is a really'") } +func TestBabblerSuffixSeed(t *testing.T) { + mb := bot.NewMockBot() + c := New(mb) + c.config.Babbler.DefaultUsers = []string{"seabass"} + assert.NotNil(t, c) + seabass := makeMessage("This is message one") + seabass.User = &user.User{Name: "seabass"} + res := c.Message(seabass) + seabass.Body = "It's easier to test with unique messages" + res = c.Message(seabass) + seabass.Body = "hi there" + res = c.Message(seabass) + res = c.Message(makeMessage("!seabass syas message one")) + res = c.Message(makeMessage("!seabass syas with unique")) + assert.Len(t, mb.Messages, 2) + assert.True(t, res) + assert.Contains(t, mb.Messages[0], "this is message one") + assert.Contains(t, mb.Messages[1], "it's easier to test with unique") +} + +func TestBabblerBadSuffixSeed(t *testing.T) { + mb := bot.NewMockBot() + c := New(mb) + c.config.Babbler.DefaultUsers = []string{"seabass"} + assert.NotNil(t, c) + seabass := makeMessage("This is message one") + seabass.User = &user.User{Name: "seabass"} + res := c.Message(seabass) + seabass.Body = "It's easier to test with unique messages" + res = c.Message(seabass) + seabass.Body = "hi there" + res = c.Message(seabass) + res = c.Message(makeMessage("!seabass syas anything true")) + assert.Len(t, mb.Messages, 1) + assert.True(t, res) + assert.Contains(t, mb.Messages[0], "seabass never said 'anything true'") +} + func TestBabblerBatch(t *testing.T) { mb := bot.NewMockBot() c := New(mb) diff --git a/plugins/babbler/commands.go b/plugins/babbler/commands.go index 38e2502..35ac59d 100644 --- a/plugins/babbler/commands.go +++ b/plugins/babbler/commands.go @@ -52,6 +52,31 @@ func (p *BabblerPlugin) getBabble(tokens []string) (string, bool) { return "", false } +func (p *BabblerPlugin) getBabbleWithSuffix(tokens []string) (string, bool) { + who := tokens[0] + _, err := p.getBabbler(who) + + if err != nil { + if err == NO_BABBLER { + return fmt.Sprintf("%s babbler not found.", who), true + } + } else { + + saying, err := p.babbleSeedSuffix(who, tokens[2:]) + + if err != nil { + if err == SAID_NOTHING { + return fmt.Sprintf("%s hasn't said anything yet.", who), true + } else if err == NEVER_SAID { + return fmt.Sprintf("%s never said '%s'", who, strings.Join(tokens[2:], " ")), true + } + } else if saying != "" { + return saying, true + } + } + return "", false +} + func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) { who := tokens[3] babblerId, err := p.getOrCreateBabbler(who)