diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index eb33bf1..68fb48a 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -110,7 +110,12 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { saidSomething := false saidWhat := "" - if numTokens >= 2 && tokens[1] == "says" { + if numTokens > 2 && tokens[1] == "says" && strings.Contains(lowercase, "syas") { + split := strings.Split(lowercase, "syas") + start := strings.Fields(split[0]) + end := strings.Fields(split[1]) + saidWhat, saidSomething = p.getBabbleWithBookends(start, end) + } else if numTokens >= 2 && tokens[1] == "says" { saidWhat, saidSomething = p.getBabble(tokens) } else if numTokens > 2 && tokens[1] == "syas" { saidWhat, saidSomething = p.getBabbleWithSuffix(tokens) @@ -553,6 +558,30 @@ func (p *BabblerPlugin) getWeightedPreviousWord(toNode *BabblerNode) (*BabblerNo return nil, nil, false, errors.New("failed to find weighted previous word") } +func (p *BabblerPlugin) verifyPhrase(babbler *Babbler, phrase []string) (*BabblerNode, *BabblerNode, error) { + curNode, err := p.getBabblerNode(babbler, phrase[0]) + if err != nil { + log.Print(err) + return nil, nil, err + } + firstNode := curNode + for i := 1; i < len(phrase); i++ { + nextNode, err := p.getBabblerNode(babbler, phrase[i]) + if err != nil { + log.Print(err) + return nil, nil, err + } + _, err = p.getBabblerArc(curNode, nextNode) + if err != nil { + log.Print(err) + return nil, nil, err + } + curNode = nextNode + } + + return firstNode, curNode, nil +} + func (p *BabblerPlugin) babble(who string) (string, error) { return p.babbleSeed(who, []string{}) } @@ -576,24 +605,11 @@ func (p *BabblerPlugin) babbleSeed(babblerName string, seed []string) (string, e } words = append(words, curWord.Word) } else { - curNode, err = p.getBabblerNode(babbler, seed[0]) + _, curNode, err = p.verifyPhrase(babbler, seed) if err != nil { log.Print(err) return "", err } - for i := 1; i < len(seed); i++ { - nextNode, err := p.getBabblerNode(babbler, seed[i]) - if err != nil { - log.Print(err) - return "", err - } - _, err = p.getBabblerArc(curNode, nextNode) - if err != nil { - log.Print(err) - return "", err - } - curNode = nextNode - } } for { @@ -720,7 +736,6 @@ func (p *BabblerPlugin) mergeBabblers(intoBabbler, otherBabbler *Babbler, intoNa return err } - func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (string, error) { babbler, err := p.getBabbler(babblerName) if err != nil { @@ -728,30 +743,11 @@ func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (str return "", nil } - var curNode *BabblerNode - var firstNode *BabblerNode - - //this is weird, but let's first make sure that the suffix exists "in order" - //because it's already writen and we need to do it anyway... - curNode, err = p.getBabblerNode(babbler, seed[0]) + firstNode, curNode, err := p.verifyPhrase(babbler, seed) if err != nil { log.Print(err) return "", err } - firstNode = curNode - for i := 1; i < len(seed); i++ { - nextNode, err := p.getBabblerNode(babbler, seed[i]) - if err != nil { - log.Print(err) - return "", err - } - _, err = p.getBabblerArc(curNode, nextNode) - if err != nil { - log.Print(err) - return "", err - } - curNode = nextNode - } words := []string{} var curWord *BabblerWord @@ -784,3 +780,136 @@ func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (str return strings.TrimSpace(strings.Join(words, " ")), nil } + +func (p *BabblerPlugin) getNextArcs(babblerNodeId int64) ([]*BabblerArc, error) { + arcs := []*BabblerArc{} + rows, err := p.db.Queryx(`select * from babblerArcs where fromNodeId = ?;`, babblerNodeId) + if err != nil { + log.Print(err) + return arcs, err + } + defer rows.Close() + + for rows.Next() { + var arc BabblerArc + err = rows.StructScan(&arc) + if err != nil { + log.Print(err) + return []*BabblerArc{}, err + } + arcs = append(arcs, &arc) + } + return arcs, nil +} + +func (p *BabblerPlugin) getBabblerNodeById(nodeId int64) (*BabblerNode, error) { + var node BabblerNode + err := p.db.QueryRowx(`select * from babblerNodes where id = ? LIMIT 1;`, nodeId).StructScan(&node) + if err != nil { + log.Print(err) + return nil, err + } + return &node, nil +} + +func shuffle(a []*BabblerArc) { + for i := range a { + j := rand.Intn(i + 1) + a[i], a[j] = a[j], a[i] + } +} + +func (p *BabblerPlugin) babbleSeedBookends(babblerName string, start, end []string) (string, error) { + babbler, err := p.getBabbler(babblerName) + if err != nil { + log.Print(err) + return "", nil + } + + _, startWordNode, err := p.verifyPhrase(babbler, start) + if err != nil { + log.Print(err) + return "", err + } + + endWordNode, _, err := p.verifyPhrase(babbler, end) + if err != nil { + log.Print(err) + return "", err + } + + type searchNode struct { + babblerNodeId int64 + previous *searchNode + } + + open := []*searchNode{ &searchNode{startWordNode.NodeId, nil} } + closed := map[int64]*searchNode{ startWordNode.NodeId : open[0] } + goalNodeId := int64(-1) + + for i := 0; i < len(open) && i < 250; i++ { + cur := open[i] + + arcs, err := p.getNextArcs(cur.babblerNodeId) + if err != nil { + return "", err + } + //add a little randomization in through child ordering + shuffle(arcs) + + for _, arc := range arcs { + if _, ok := closed[arc.ToNodeId]; !ok { + child := &searchNode{arc.ToNodeId, cur} + open = append(open, child) + closed[arc.ToNodeId] = child + + if arc.ToNodeId == endWordNode.NodeId { + goalNodeId = cur.babblerNodeId + break + } + } + } + } + + if goalNodeId == -1 { + return "", errors.New("couldn't find path") + } else if closed[goalNodeId].previous == nil { + seeds := append(start, end...) + return strings.Join(seeds, " "), nil + } + + words := []string{} + + curSearchNode := closed[goalNodeId] + + for { + cur, err := p.getBabblerNodeById(curSearchNode.babblerNodeId) + if err != nil { + log.Print(err) + return "", err + } + var w BabblerWord + err = p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, cur.WordId).StructScan(&w) + if err != nil { + log.Print(err) + return "", err + } + words = append(words, w.Word) + + curSearchNode = closed[curSearchNode.previous.babblerNodeId] + + if curSearchNode.previous == nil { + break + } + } + + for i := 0; i < len(words) / 2; i++ { + index := len(words)-(i+1) + words[i], words[index] = words[index], words[i] + } + + words = append(start, words...) + words = append(words, end...) + + return strings.Join(words, " "), nil +} diff --git a/plugins/babbler/babbler_test.go b/plugins/babbler/babbler_test.go index 444f739..a54c625 100644 --- a/plugins/babbler/babbler_test.go +++ b/plugins/babbler/babbler_test.go @@ -195,6 +195,48 @@ func TestBabblerBadSuffixSeed(t *testing.T) { assert.Contains(t, mb.Messages[0], "seabass never said 'anything true'") } +func TestBabblerBookendSeed(t *testing.T) { + mb := bot.NewMockBot() + c := New(mb) + c.config.Babbler.DefaultUsers = []string{"seabass"} + assert.NotNil(t, c) + seabass := makeMessage("It's easier to test with unique messages") + seabass.User = &user.User{Name: "seabass"} + res := c.Message(seabass) + res = c.Message(makeMessage("!seabass says It's easier syas unique messages")) + assert.Len(t, mb.Messages, 1) + assert.True(t, res) + assert.Contains(t, mb.Messages[0], "it's easier to test with unique messages") +} + +func TestBabblerBookendSeedShort(t *testing.T) { + mb := bot.NewMockBot() + c := New(mb) + c.config.Babbler.DefaultUsers = []string{"seabass"} + assert.NotNil(t, c) + seabass := makeMessage("It's easier to test with unique messages") + seabass.User = &user.User{Name: "seabass"} + res := c.Message(seabass) + res = c.Message(makeMessage("!seabass says It's easier to test with syas unique messages")) + assert.Len(t, mb.Messages, 1) + assert.True(t, res) + assert.Contains(t, mb.Messages[0], "it's easier to test with unique messages") +} + +func TestBabblerBadBookendSeed(t *testing.T) { + mb := bot.NewMockBot() + c := New(mb) + c.config.Babbler.DefaultUsers = []string{"seabass"} + assert.NotNil(t, c) + seabass := makeMessage("It's easier to test with unique messages") + seabass.User = &user.User{Name: "seabass"} + res := c.Message(seabass) + res = c.Message(makeMessage("!seabass says It's easier syas not unique messages")) + assert.Len(t, mb.Messages, 1) + assert.True(t, res) + assert.Contains(t, mb.Messages[0], "seabass never said 'it's easier ... not unique messages'") +} + func TestBabblerBatch(t *testing.T) { mb := bot.NewMockBot() c := New(mb) diff --git a/plugins/babbler/commands.go b/plugins/babbler/commands.go index 35ac59d..c501346 100644 --- a/plugins/babbler/commands.go +++ b/plugins/babbler/commands.go @@ -77,6 +77,33 @@ func (p *BabblerPlugin) getBabbleWithSuffix(tokens []string) (string, bool) { return "", false } +func (p *BabblerPlugin) getBabbleWithBookends(start, end []string) (string, bool) { + who := start[0] + _, err := p.getBabbler(who) + + if err != nil { + if err == NO_BABBLER { + return fmt.Sprintf("%s babbler not found.", who), true + } + } else { + + saying, err := p.babbleSeedBookends(who, start[2:], end) + + if err != nil { + if err == SAID_NOTHING { + return fmt.Sprintf("%s hasn't said anything yet.", who), true + } else if err == NEVER_SAID { + seeds := append(start[2:], "...") + seeds = append(seeds, end...) + return fmt.Sprintf("%s never said '%s'", who, strings.Join(seeds, " ")), true + } + } else if saying != "" { + return saying, true + } + } + return "", false +} + func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) { who := tokens[3] babblerId, err := p.getOrCreateBabbler(who)