mirror of https://github.com/velour/catbase.git
allow babbles to formed with suffixes (closes #55)
This commit is contained in:
parent
0f94b6203e
commit
aaba7c5bff
|
@ -112,6 +112,8 @@ func (p *BabblerPlugin) Message(message msg.Message) bool {
|
||||||
|
|
||||||
if numTokens >= 2 && tokens[1] == "says" {
|
if numTokens >= 2 && tokens[1] == "says" {
|
||||||
saidWhat, saidSomething = p.getBabble(tokens)
|
saidWhat, saidSomething = p.getBabble(tokens)
|
||||||
|
} else if numTokens > 2 && tokens[1] == "syas" {
|
||||||
|
saidWhat, saidSomething = p.getBabbleWithSuffix(tokens)
|
||||||
} else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 {
|
} else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 {
|
||||||
saidWhat, saidSomething = p.initializeBabbler(tokens)
|
saidWhat, saidSomething = p.initializeBabbler(tokens)
|
||||||
} else if strings.Index(lowercase, "batch learn for ") == 0 {
|
} else if strings.Index(lowercase, "batch learn for ") == 0 {
|
||||||
|
@ -493,6 +495,64 @@ func (p *BabblerPlugin) getWeightedNextWord(fromNode *BabblerNode) (*BabblerNode
|
||||||
return nil, nil, errors.New("failed to find weighted next word")
|
return nil, nil, errors.New("failed to find weighted next word")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) getWeightedPreviousWord(toNode *BabblerNode) (*BabblerNode, *BabblerWord, bool, error) {
|
||||||
|
rows, err := p.db.Queryx(`select * from babblerArcs where toNodeId = ?;`, toNode.NodeId)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil, nil, false, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
arcs := []*BabblerArc{}
|
||||||
|
total := int64(0)
|
||||||
|
for rows.Next() {
|
||||||
|
var arc BabblerArc
|
||||||
|
err = rows.StructScan(&arc)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil, nil, false, err
|
||||||
|
}
|
||||||
|
arcs = append(arcs, &arc)
|
||||||
|
total += arc.Frequency
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(arcs) == 0 {
|
||||||
|
return nil, nil, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
which := rand.Int63n(total + toNode.RootFrequency)
|
||||||
|
|
||||||
|
//terminate the babble
|
||||||
|
if which >= total {
|
||||||
|
return nil, nil, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
for _, arc := range arcs {
|
||||||
|
|
||||||
|
total += arc.Frequency
|
||||||
|
|
||||||
|
if total >= which {
|
||||||
|
var node BabblerNode
|
||||||
|
err := p.db.QueryRowx(`select * from babblerNodes where id = ? LIMIT 1;`, arc.FromNodeId).StructScan(&node)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil, nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var w BabblerWord
|
||||||
|
err = p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, node.WordId).StructScan(&w)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil, nil, false, err
|
||||||
|
}
|
||||||
|
return &node, &w, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Fatalf("shouldn't happen")
|
||||||
|
return nil, nil, false, errors.New("failed to find weighted previous word")
|
||||||
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) babble(who string) (string, error) {
|
func (p *BabblerPlugin) babble(who string) (string, error) {
|
||||||
return p.babbleSeed(who, []string{})
|
return p.babbleSeed(who, []string{})
|
||||||
}
|
}
|
||||||
|
@ -659,3 +719,68 @@ func (p *BabblerPlugin) mergeBabblers(intoBabbler, otherBabbler *Babbler, intoNa
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (string, error) {
|
||||||
|
babbler, err := p.getBabbler(babblerName)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var curNode *BabblerNode
|
||||||
|
var firstNode *BabblerNode
|
||||||
|
|
||||||
|
//this is weird, but let's first make sure that the suffix exists "in order"
|
||||||
|
//because it's already writen and we need to do it anyway...
|
||||||
|
curNode, err = p.getBabblerNode(babbler, seed[0])
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
firstNode = curNode
|
||||||
|
for i := 1; i < len(seed); i++ {
|
||||||
|
nextNode, err := p.getBabblerNode(babbler, seed[i])
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
_, err = p.getBabblerArc(curNode, nextNode)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
curNode = nextNode
|
||||||
|
}
|
||||||
|
|
||||||
|
words := []string{}
|
||||||
|
var curWord *BabblerWord
|
||||||
|
var shouldTerminate bool
|
||||||
|
curNode = firstNode
|
||||||
|
for {
|
||||||
|
curNode, curWord, shouldTerminate, err = p.getWeightedPreviousWord(curNode)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldTerminate {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
words = append(words, curWord.Word)
|
||||||
|
|
||||||
|
if len(words) >= 250 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(words) / 2; i++ {
|
||||||
|
index := len(words)-(i+1)
|
||||||
|
words[i], words[index] = words[index], words[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
words = append(words, seed...)
|
||||||
|
|
||||||
|
return strings.TrimSpace(strings.Join(words, " ")), nil
|
||||||
|
}
|
||||||
|
|
|
@ -157,6 +157,44 @@ func TestBabblerBadSeed2(t *testing.T) {
|
||||||
assert.Contains(t, mb.Messages[0], "seabass never said 'this is a really'")
|
assert.Contains(t, mb.Messages[0], "seabass never said 'this is a really'")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBabblerSuffixSeed(t *testing.T) {
|
||||||
|
mb := bot.NewMockBot()
|
||||||
|
c := New(mb)
|
||||||
|
c.config.Babbler.DefaultUsers = []string{"seabass"}
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
seabass := makeMessage("This is message one")
|
||||||
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
|
res := c.Message(seabass)
|
||||||
|
seabass.Body = "It's easier to test with unique messages"
|
||||||
|
res = c.Message(seabass)
|
||||||
|
seabass.Body = "hi there"
|
||||||
|
res = c.Message(seabass)
|
||||||
|
res = c.Message(makeMessage("!seabass syas message one"))
|
||||||
|
res = c.Message(makeMessage("!seabass syas with unique"))
|
||||||
|
assert.Len(t, mb.Messages, 2)
|
||||||
|
assert.True(t, res)
|
||||||
|
assert.Contains(t, mb.Messages[0], "this is message one")
|
||||||
|
assert.Contains(t, mb.Messages[1], "it's easier to test with unique")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBabblerBadSuffixSeed(t *testing.T) {
|
||||||
|
mb := bot.NewMockBot()
|
||||||
|
c := New(mb)
|
||||||
|
c.config.Babbler.DefaultUsers = []string{"seabass"}
|
||||||
|
assert.NotNil(t, c)
|
||||||
|
seabass := makeMessage("This is message one")
|
||||||
|
seabass.User = &user.User{Name: "seabass"}
|
||||||
|
res := c.Message(seabass)
|
||||||
|
seabass.Body = "It's easier to test with unique messages"
|
||||||
|
res = c.Message(seabass)
|
||||||
|
seabass.Body = "hi there"
|
||||||
|
res = c.Message(seabass)
|
||||||
|
res = c.Message(makeMessage("!seabass syas anything true"))
|
||||||
|
assert.Len(t, mb.Messages, 1)
|
||||||
|
assert.True(t, res)
|
||||||
|
assert.Contains(t, mb.Messages[0], "seabass never said 'anything true'")
|
||||||
|
}
|
||||||
|
|
||||||
func TestBabblerBatch(t *testing.T) {
|
func TestBabblerBatch(t *testing.T) {
|
||||||
mb := bot.NewMockBot()
|
mb := bot.NewMockBot()
|
||||||
c := New(mb)
|
c := New(mb)
|
||||||
|
|
|
@ -52,6 +52,31 @@ func (p *BabblerPlugin) getBabble(tokens []string) (string, bool) {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *BabblerPlugin) getBabbleWithSuffix(tokens []string) (string, bool) {
|
||||||
|
who := tokens[0]
|
||||||
|
_, err := p.getBabbler(who)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == NO_BABBLER {
|
||||||
|
return fmt.Sprintf("%s babbler not found.", who), true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
saying, err := p.babbleSeedSuffix(who, tokens[2:])
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == SAID_NOTHING {
|
||||||
|
return fmt.Sprintf("%s hasn't said anything yet.", who), true
|
||||||
|
} else if err == NEVER_SAID {
|
||||||
|
return fmt.Sprintf("%s never said '%s'", who, strings.Join(tokens[2:], " ")), true
|
||||||
|
}
|
||||||
|
} else if saying != "" {
|
||||||
|
return saying, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) {
|
func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) {
|
||||||
who := tokens[3]
|
who := tokens[3]
|
||||||
babblerId, err := p.getOrCreateBabbler(who)
|
babblerId, err := p.getOrCreateBabbler(who)
|
||||||
|
|
Loading…
Reference in New Issue