mirror of https://github.com/velour/catbase.git
Merge pull request #59 from velour/babble_bookend
Even more custom babbling... with bookends! closed #58
This commit is contained in:
commit
98298dfb84
|
@ -110,7 +110,12 @@ func (p *BabblerPlugin) Message(message msg.Message) bool {
|
||||||
saidSomething := false
|
saidSomething := false
|
||||||
saidWhat := ""
|
saidWhat := ""
|
||||||
|
|
||||||
if numTokens >= 2 && tokens[1] == "says" {
|
if numTokens > 2 && tokens[1] == "says" && strings.Contains(lowercase, "syas") {
|
||||||
|
split := strings.Split(lowercase, "syas")
|
||||||
|
start := strings.Fields(split[0])
|
||||||
|
end := strings.Fields(split[1])
|
||||||
|
saidWhat, saidSomething = p.getBabbleWithBookends(start, end)
|
||||||
|
} else if numTokens >= 2 && tokens[1] == "says" {
|
||||||
saidWhat, saidSomething = p.getBabble(tokens)
|
saidWhat, saidSomething = p.getBabble(tokens)
|
||||||
} else if numTokens > 2 && tokens[1] == "syas" {
|
} else if numTokens > 2 && tokens[1] == "syas" {
|
||||||
saidWhat, saidSomething = p.getBabbleWithSuffix(tokens)
|
saidWhat, saidSomething = p.getBabbleWithSuffix(tokens)
|
||||||
|
@ -553,6 +558,30 @@ func (p *BabblerPlugin) getWeightedPreviousWord(toNode *BabblerNode) (*BabblerNo
|
||||||
return nil, nil, false, errors.New("failed to find weighted previous word")
|
return nil, nil, false, errors.New("failed to find weighted previous word")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) verifyPhrase(babbler *Babbler, phrase []string) (*BabblerNode, *BabblerNode, error) {
|
||||||
|
curNode, err := p.getBabblerNode(babbler, phrase[0])
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
firstNode := curNode
|
||||||
|
for i := 1; i < len(phrase); i++ {
|
||||||
|
nextNode, err := p.getBabblerNode(babbler, phrase[i])
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
_, err = p.getBabblerArc(curNode, nextNode)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
curNode = nextNode
|
||||||
|
}
|
||||||
|
|
||||||
|
return firstNode, curNode, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) babble(who string) (string, error) {
|
func (p *BabblerPlugin) babble(who string) (string, error) {
|
||||||
return p.babbleSeed(who, []string{})
|
return p.babbleSeed(who, []string{})
|
||||||
}
|
}
|
||||||
|
@ -576,24 +605,11 @@ func (p *BabblerPlugin) babbleSeed(babblerName string, seed []string) (string, e
|
||||||
}
|
}
|
||||||
words = append(words, curWord.Word)
|
words = append(words, curWord.Word)
|
||||||
} else {
|
} else {
|
||||||
curNode, err = p.getBabblerNode(babbler, seed[0])
|
_, curNode, err = p.verifyPhrase(babbler, seed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
for i := 1; i < len(seed); i++ {
|
|
||||||
nextNode, err := p.getBabblerNode(babbler, seed[i])
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
_, err = p.getBabblerArc(curNode, nextNode)
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
curNode = nextNode
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
@ -720,7 +736,6 @@ func (p *BabblerPlugin) mergeBabblers(intoBabbler, otherBabbler *Babbler, intoNa
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (string, error) {
|
func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (string, error) {
|
||||||
babbler, err := p.getBabbler(babblerName)
|
babbler, err := p.getBabbler(babblerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -728,30 +743,11 @@ func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (str
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var curNode *BabblerNode
|
firstNode, curNode, err := p.verifyPhrase(babbler, seed)
|
||||||
var firstNode *BabblerNode
|
|
||||||
|
|
||||||
//this is weird, but let's first make sure that the suffix exists "in order"
|
|
||||||
//because it's already writen and we need to do it anyway...
|
|
||||||
curNode, err = p.getBabblerNode(babbler, seed[0])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
firstNode = curNode
|
|
||||||
for i := 1; i < len(seed); i++ {
|
|
||||||
nextNode, err := p.getBabblerNode(babbler, seed[i])
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
_, err = p.getBabblerArc(curNode, nextNode)
|
|
||||||
if err != nil {
|
|
||||||
log.Print(err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
curNode = nextNode
|
|
||||||
}
|
|
||||||
|
|
||||||
words := []string{}
|
words := []string{}
|
||||||
var curWord *BabblerWord
|
var curWord *BabblerWord
|
||||||
|
@ -784,3 +780,136 @@ func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (str
|
||||||
|
|
||||||
return strings.TrimSpace(strings.Join(words, " ")), nil
|
return strings.TrimSpace(strings.Join(words, " ")), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) getNextArcs(babblerNodeId int64) ([]*BabblerArc, error) {
|
||||||
|
arcs := []*BabblerArc{}
|
||||||
|
rows, err := p.db.Queryx(`select * from babblerArcs where fromNodeId = ?;`, babblerNodeId)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return arcs, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var arc BabblerArc
|
||||||
|
err = rows.StructScan(&arc)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return []*BabblerArc{}, err
|
||||||
|
}
|
||||||
|
arcs = append(arcs, &arc)
|
||||||
|
}
|
||||||
|
return arcs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) getBabblerNodeById(nodeId int64) (*BabblerNode, error) {
|
||||||
|
var node BabblerNode
|
||||||
|
err := p.db.QueryRowx(`select * from babblerNodes where id = ? LIMIT 1;`, nodeId).StructScan(&node)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func shuffle(a []*BabblerArc) {
|
||||||
|
for i := range a {
|
||||||
|
j := rand.Intn(i + 1)
|
||||||
|
a[i], a[j] = a[j], a[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) babbleSeedBookends(babblerName string, start, end []string) (string, error) {
|
||||||
|
babbler, err := p.getBabbler(babblerName)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, startWordNode, err := p.verifyPhrase(babbler, start)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
endWordNode, _, err := p.verifyPhrase(babbler, end)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
type searchNode struct {
|
||||||
|
babblerNodeId int64
|
||||||
|
previous *searchNode
|
||||||
|
}
|
||||||
|
|
||||||
|
open := []*searchNode{ &searchNode{startWordNode.NodeId, nil} }
|
||||||
|
closed := map[int64]*searchNode{ startWordNode.NodeId : open[0] }
|
||||||
|
goalNodeId := int64(-1)
|
||||||
|
|
||||||
|
for i := 0; i < len(open) && i < 250; i++ {
|
||||||
|
cur := open[i]
|
||||||
|
|
||||||
|
arcs, err := p.getNextArcs(cur.babblerNodeId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
//add a little randomization in through child ordering
|
||||||
|
shuffle(arcs)
|
||||||
|
|
||||||
|
for _, arc := range arcs {
|
||||||
|
if _, ok := closed[arc.ToNodeId]; !ok {
|
||||||
|
child := &searchNode{arc.ToNodeId, cur}
|
||||||
|
open = append(open, child)
|
||||||
|
closed[arc.ToNodeId] = child
|
||||||
|
|
||||||
|
if arc.ToNodeId == endWordNode.NodeId {
|
||||||
|
goalNodeId = cur.babblerNodeId
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if goalNodeId == -1 {
|
||||||
|
return "", errors.New("couldn't find path")
|
||||||
|
} else if closed[goalNodeId].previous == nil {
|
||||||
|
seeds := append(start, end...)
|
||||||
|
return strings.Join(seeds, " "), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
words := []string{}
|
||||||
|
|
||||||
|
curSearchNode := closed[goalNodeId]
|
||||||
|
|
||||||
|
for {
|
||||||
|
cur, err := p.getBabblerNodeById(curSearchNode.babblerNodeId)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var w BabblerWord
|
||||||
|
err = p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, cur.WordId).StructScan(&w)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
words = append(words, w.Word)
|
||||||
|
|
||||||
|
curSearchNode = closed[curSearchNode.previous.babblerNodeId]
|
||||||
|
|
||||||
|
if curSearchNode.previous == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(words) / 2; i++ {
|
||||||
|
index := len(words)-(i+1)
|
||||||
|
words[i], words[index] = words[index], words[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
words = append(start, words...)
|
||||||
|
words = append(words, end...)
|
||||||
|
|
||||||
|
return strings.Join(words, " "), nil
|
||||||
|
}
|
||||||
|
|
|
@ -195,6 +195,48 @@ func TestBabblerBadSuffixSeed(t *testing.T) {
|
||||||
assert.Contains(t, mb.Messages[0], "seabass never said 'anything true'")
|
assert.Contains(t, mb.Messages[0], "seabass never said 'anything true'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBabblerBookendSeed(t *testing.T) {
|
||||||
|
mb := bot.NewMockBot()
|
||||||
|
c := New(mb)
|
||||||
|
c.config.Babbler.DefaultUsers = []string{"seabass"}
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
seabass := makeMessage("It's easier to test with unique messages")
|
||||||
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
|
res := c.Message(seabass)
|
||||||
|
res = c.Message(makeMessage("!seabass says It's easier syas unique messages"))
|
||||||
|
assert.Len(t, mb.Messages, 1)
|
||||||
|
assert.True(t, res)
|
||||||
|
assert.Contains(t, mb.Messages[0], "it's easier to test with unique messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBabblerBookendSeedShort(t *testing.T) {
|
||||||
|
mb := bot.NewMockBot()
|
||||||
|
c := New(mb)
|
||||||
|
c.config.Babbler.DefaultUsers = []string{"seabass"}
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
seabass := makeMessage("It's easier to test with unique messages")
|
||||||
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
|
res := c.Message(seabass)
|
||||||
|
res = c.Message(makeMessage("!seabass says It's easier to test with syas unique messages"))
|
||||||
|
assert.Len(t, mb.Messages, 1)
|
||||||
|
assert.True(t, res)
|
||||||
|
assert.Contains(t, mb.Messages[0], "it's easier to test with unique messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBabblerBadBookendSeed(t *testing.T) {
|
||||||
|
mb := bot.NewMockBot()
|
||||||
|
c := New(mb)
|
||||||
|
c.config.Babbler.DefaultUsers = []string{"seabass"}
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
seabass := makeMessage("It's easier to test with unique messages")
|
||||||
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
|
res := c.Message(seabass)
|
||||||
|
res = c.Message(makeMessage("!seabass says It's easier syas not unique messages"))
|
||||||
|
assert.Len(t, mb.Messages, 1)
|
||||||
|
assert.True(t, res)
|
||||||
|
assert.Contains(t, mb.Messages[0], "seabass never said 'it's easier ... not unique messages'")
|
||||||
|
}
|
||||||
|
|
||||||
func TestBabblerBatch(t *testing.T) {
|
func TestBabblerBatch(t *testing.T) {
|
||||||
mb := bot.NewMockBot()
|
mb := bot.NewMockBot()
|
||||||
c := New(mb)
|
c := New(mb)
|
||||||
|
|
|
@ -77,6 +77,33 @@ func (p *BabblerPlugin) getBabbleWithSuffix(tokens []string) (string, bool) {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) getBabbleWithBookends(start, end []string) (string, bool) {
|
||||||
|
who := start[0]
|
||||||
|
_, err := p.getBabbler(who)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == NO_BABBLER {
|
||||||
|
return fmt.Sprintf("%s babbler not found.", who), true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
saying, err := p.babbleSeedBookends(who, start[2:], end)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == SAID_NOTHING {
|
||||||
|
return fmt.Sprintf("%s hasn't said anything yet.", who), true
|
||||||
|
} else if err == NEVER_SAID {
|
||||||
|
seeds := append(start[2:], "...")
|
||||||
|
seeds = append(seeds, end...)
|
||||||
|
return fmt.Sprintf("%s never said '%s'", who, strings.Join(seeds, " ")), true
|
||||||
|
}
|
||||||
|
} else if saying != "" {
|
||||||
|
return saying, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) {
|
func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) {
|
||||||
who := tokens[3]
|
who := tokens[3]
|
||||||
babblerId, err := p.getOrCreateBabbler(who)
|
babblerId, err := p.getOrCreateBabbler(who)
|
||||||
|
|
Loading…
Reference in New Issue