mirror of https://github.com/velour/catbase.git
872 lines
20 KiB
Go
872 lines
20 KiB
Go
// © 2013 the CatBase Authors under the WTFPL. See AUTHORS for the list of authors.
|
|
|
|
package babbler
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
bh "github.com/timshannon/bolthold"
|
|
"github.com/velour/catbase/plugins/remember"
|
|
"math/rand"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
|
|
"github.com/velour/catbase/bot"
|
|
"github.com/velour/catbase/bot/msg"
|
|
)
|
|
|
|
// Sentinel errors reported by babbler lookups. Code in this file compares
// against them with ==, so they are returned as-is (never wrapped).
var (
	// NO_BABBLER means no babbler record exists for the requested name.
	NO_BABBLER = errors.New("babbler not found")

	// SAID_NOTHING means the babbler has no root word to start a babble from.
	SAID_NOTHING = errors.New("hasn't said anything yet")

	// NEVER_SAID means a requested word, node, or arc has never been recorded.
	NEVER_SAID = errors.New("never said that")
)
|
|
|
|
// BabblerPlugin is the bot plugin that keeps per-speaker word chains in the
// bot's bolthold store and answers the "says" family of chat commands.
type BabblerPlugin struct {
	// Bot is the bot this plugin was created for; replies are sent through it.
	Bot bot.Bot

	// store is the bot's bolthold store (b.Store(), captured in New).
	store *bh.Store

	// WithGoRoutines is set to true by New; it is not read in the part of
	// the file visible here.
	WithGoRoutines bool

	// handlers is the table handed to Bot.RegisterTable in register.
	handlers bot.HandlerTable
}
|
|
|
|
// Babbler is the stored record for one named babbler (one speaker).
//
// NOTE(review): the `db` tags look like leftovers from an SQL schema;
// bolthold's Where() matches Go field names (e.g. "Name"), not these tags.
// Also confirm that bolthold really fills BabblerId on Insert: it documents a
// `boltholdKey` tag (not `boltholdid`) on a field of the same type as the key
// (uint64 for bh.NextSequence()).
type Babbler struct {
	BabblerId int64  `db:"id" boltholdid:"BabblerId"`
	Name      string `db:"babbler"`
}
|
|
|
|
func getBabbler(store *bh.Store, id int64) (*Babbler, error) {
|
|
res := &Babbler{}
|
|
err := store.Get(id, res)
|
|
return res, err
|
|
}
|
|
|
|
// BabblerWord is a stored vocabulary entry. Lookups by text (getWord) do not
// filter by babbler, so one record is shared by every babbler's nodes, which
// refer to it through WordId.
//
// NOTE(review): as with Babbler, the `db` tags are not what bolthold queries
// match on, and `boltholdid` may not cause WordId to be filled on Insert.
type BabblerWord struct {
	WordId int64  `db:"id" boltholdid:"WordId"`
	Word   string `db:"word"`
}
|
|
|
|
func getWord(store *bh.Store, id int64) (*BabblerWord, error) {
|
|
res := &BabblerWord{}
|
|
err := store.Get(id, res)
|
|
return res, err
|
|
}
|
|
|
|
// BabblerNode is one word (WordId) in one babbler's (BabblerId) chain.
//
// Root is set to 1 once the word has begun a phrase, and RootFrequency counts
// how often it did (see incrementRootWordFrequency); the weighted root pick
// uses RootFrequency as the word's weight.
//
// NOTE(review): the `db` tags are not what bolthold queries match on, and
// `boltholdid` may not cause NodeId to be filled on Insert — confirm.
type BabblerNode struct {
	NodeId        int64 `db:"id" boltholdid:"NodeId"`
	BabblerId     int64 `db:"babblerId"`
	WordId        int64 `db:"wordId"`
	Root          int64 `db:"root"`
	RootFrequency int64 `db:"rootFrequency"`
}
|
|
|
|
func getNode(store *bh.Store, id int64) (*BabblerNode, error) {
|
|
res := &BabblerNode{}
|
|
err := store.Get(id, res)
|
|
return res, err
|
|
}
|
|
|
|
// BabblerArc is a directed transition between two nodes: ToNodeId followed
// FromNodeId, and Frequency counts how many times that transition was seen
// (see incrementWordArc).
//
// NOTE(review): the `db` tags are not what bolthold queries match on, and
// `boltholdid` may not cause ArcId to be filled on Insert — confirm.
type BabblerArc struct {
	ArcId      int64 `db:"id" boltholdid:"ArcId"`
	FromNodeId int64 `db:"fromNodeId"`
	ToNodeId   int64 `db:"toNodeId"`
	Frequency  int64 `db:"frequency"`
}
|
|
|
|
func getArc(store *bh.Store, id int64) (*BabblerArc, error) {
|
|
res := &BabblerArc{}
|
|
err := store.Get(id, res)
|
|
return res, err
|
|
}
|
|
|
|
func New(b bot.Bot) *BabblerPlugin {
|
|
plugin := &BabblerPlugin{
|
|
Bot: b,
|
|
store: b.Store(),
|
|
WithGoRoutines: true,
|
|
}
|
|
|
|
plugin.createNewWord("")
|
|
|
|
plugin.register()
|
|
|
|
return plugin
|
|
}
|
|
|
|
func (p *BabblerPlugin) register() {
|
|
p.handlers = bot.HandlerTable{
|
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: false,
|
|
Regex: regexp.MustCompile(`(?i)^(?P<who>\S+) says-bridge (?P<start>.+)\|(?P<end>.+)$`),
|
|
Handler: func(r bot.Request) bool {
|
|
who := r.Values["who"]
|
|
start := strings.Fields(strings.ToLower(r.Values["start"]))
|
|
end := strings.Fields(strings.ToLower(r.Values["end"]))
|
|
return p.sayIt(r, p.getBabbleWithBookends(who, start, end))
|
|
}},
|
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: false,
|
|
Regex: regexp.MustCompile(`(?i)^(?P<who>\S+) says-tail (?P<what>.*)$`),
|
|
Handler: func(r bot.Request) bool {
|
|
who := r.Values["who"]
|
|
what := strings.Fields(strings.ToLower(r.Values["what"]))
|
|
return p.sayIt(r, p.getBabbleWithSuffix(who, what))
|
|
}},
|
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: false,
|
|
Regex: regexp.MustCompile(`(?i)^(?P<who>\S+) says-middle-out (?P<what>.*)$`),
|
|
Handler: func(r bot.Request) bool {
|
|
who := r.Values["who"]
|
|
what := strings.ToLower(r.Values["what"])
|
|
tokens := strings.Fields(what)
|
|
saidSomething := false
|
|
saidWhat := ""
|
|
|
|
saidWhatStart := p.getBabbleWithSuffix(who, tokens)
|
|
saidSomethingStart := saidWhatStart != ""
|
|
neverSaidLooksLike := fmt.Sprintf("%s never said", who)
|
|
if !saidSomethingStart || strings.HasPrefix(saidWhatStart, neverSaidLooksLike) {
|
|
saidSomething = saidSomethingStart
|
|
saidWhat = saidWhatStart
|
|
} else {
|
|
saidWhatEnd := p.getBabble(who, tokens)
|
|
saidSomethingEnd := saidWhatEnd != ""
|
|
saidSomething = saidSomethingStart && saidSomethingEnd
|
|
if saidSomething {
|
|
saidWhat = saidWhatStart + strings.TrimPrefix(saidWhatEnd, what)
|
|
}
|
|
}
|
|
return p.sayIt(r, saidWhat)
|
|
}},
|
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: false, Regex: regexp.MustCompile(`(?i)^(?P<who>\S+) (says (?P<what>.*)?|says)$`),
|
|
Handler: func(r bot.Request) bool {
|
|
who := r.Values["who"]
|
|
what := strings.Fields(strings.ToLower(r.Values["what"]))
|
|
return p.sayIt(r, p.getBabble(who, what))
|
|
}},
|
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: false,
|
|
Regex: regexp.MustCompile(`(?i)^initialize babbler for (?P<who>\S+)$`),
|
|
Handler: func(r bot.Request) bool {
|
|
who := r.Values["who"]
|
|
return p.sayIt(r, p.initializeBabbler(who))
|
|
}},
|
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: false,
|
|
Regex: regexp.MustCompile(`(?i)^merge babbler (?P<from>\S+) into (?P<to>\S+)$`),
|
|
Handler: func(r bot.Request) bool {
|
|
from, to := r.Values["from"], r.Values["to"]
|
|
return p.sayIt(r, p.merge(from, to))
|
|
}},
|
|
bot.HandlerSpec{Kind: bot.Message, IsCmd: false,
|
|
Regex: regexp.MustCompile(`.*`),
|
|
Handler: func(r bot.Request) bool {
|
|
p.addToBabbler(r.Msg.User.Name, strings.ToLower(r.Msg.Body))
|
|
return false
|
|
}},
|
|
}
|
|
p.Bot.RegisterTable(p, p.handlers)
|
|
p.Bot.Register(p, bot.Help, p.help)
|
|
}
|
|
|
|
func (p *BabblerPlugin) sayIt(r bot.Request, what string) bool {
|
|
if what != "" {
|
|
p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, what)
|
|
}
|
|
return what != ""
|
|
}
|
|
|
|
func (p *BabblerPlugin) help(c bot.Connector, kind bot.Kind, msg msg.Message, args ...interface{}) bool {
|
|
commands := []string{
|
|
"initialize babbler for seabass",
|
|
"merge babbler drseabass into seabass",
|
|
"seabass says ...",
|
|
"seabass says-tail ...",
|
|
"seabass says-middle-out ...",
|
|
"seabass says-bridge ... | ...",
|
|
}
|
|
p.Bot.Send(c, bot.Message, msg.Channel, strings.Join(commands, "\n\n"))
|
|
return true
|
|
}
|
|
|
|
func (p *BabblerPlugin) makeBabbler(name string) (*Babbler, error) {
|
|
b := &Babbler{
|
|
Name: name,
|
|
}
|
|
err := p.store.Insert(bh.NextSequence(), b)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
return b, err
|
|
}
|
|
|
|
func (p *BabblerPlugin) getBabbler(name string) (*Babbler, error) {
|
|
var bblr Babbler
|
|
err := p.store.FindOne(&bblr, bh.Where("babbler").Eq(name))
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
log.Error().Msg("failed to find babbler")
|
|
return nil, NO_BABBLER
|
|
}
|
|
log.Error().Err(err).Msg("encountered problem in babbler lookup")
|
|
return nil, err
|
|
}
|
|
return &bblr, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) getOrCreateBabbler(name string) (*Babbler, error) {
|
|
babbler, err := p.getBabbler(name)
|
|
if err == NO_BABBLER {
|
|
babbler, err = p.makeBabbler(name)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
|
|
quotes := remember.AllQuotesFrom(p.store, babbler.Name)
|
|
for _, q := range quotes {
|
|
if err = p.addToMarkovChain(babbler, q.Tidbit); err != nil {
|
|
log.Error().Err(err)
|
|
}
|
|
}
|
|
}
|
|
return babbler, err
|
|
}
|
|
|
|
func (p *BabblerPlugin) getWord(word string) (*BabblerWord, error) {
|
|
var w BabblerWord
|
|
err := p.store.FindOne(&w, bh.Where("word").Eq(word).Limit(1))
|
|
if err != nil {
|
|
if err == bh.ErrNotFound {
|
|
return nil, NEVER_SAID
|
|
}
|
|
return nil, err
|
|
}
|
|
return &w, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) createNewWord(word string) (*BabblerWord, error) {
|
|
w := &BabblerWord{Word: word}
|
|
err := p.store.Insert(bh.NextSequence(), w)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
return w, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) getOrCreateWord(word string) (*BabblerWord, error) {
|
|
if w, err := p.getWord(word); err == NEVER_SAID {
|
|
return p.createNewWord(word)
|
|
} else {
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
}
|
|
return w, err
|
|
}
|
|
}
|
|
|
|
func (p *BabblerPlugin) getBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) {
|
|
w, err := p.getWord(word)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var node BabblerNode
|
|
err = p.store.FindOne(&node, bh.Where("babblerId").Eq(babbler.BabblerId).And("wordId").Eq(w.WordId))
|
|
if err != nil {
|
|
if err == bh.ErrNotFound {
|
|
return nil, NEVER_SAID
|
|
}
|
|
return nil, err
|
|
}
|
|
return &node, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) createBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) {
|
|
w, err := p.getOrCreateWord(word)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
|
|
bn := &BabblerNode{
|
|
WordId: w.WordId,
|
|
Root: 0,
|
|
RootFrequency: 0,
|
|
}
|
|
|
|
err = p.store.Insert(bh.NextSequence(), bn)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
|
|
return bn, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) getOrCreateBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) {
|
|
node, err := p.getBabblerNode(babbler, word)
|
|
if err != nil {
|
|
return p.createBabblerNode(babbler, word)
|
|
}
|
|
return node, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) incrementRootWordFrequency(babbler *Babbler, word string) (*BabblerNode, error) {
|
|
node, err := p.getOrCreateBabblerNode(babbler, word)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
err = p.store.UpdateMatching(BabblerNode{}, bh.Where("id").Eq(node.NodeId), func(record interface{}) error {
|
|
r := record.(BabblerNode)
|
|
r.RootFrequency += 1
|
|
r.Root = 1
|
|
return p.store.Update(r.NodeId, r)
|
|
})
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
node.RootFrequency += 1
|
|
return node, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) getBabblerArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) {
|
|
var arc BabblerArc
|
|
err := p.store.FindOne(&arc, bh.Where("fromNodeId").Eq(fromNode.NodeId).And("toNodeId").Eq(toNode.NodeId))
|
|
if err != nil {
|
|
if err == bh.ErrNotFound {
|
|
return nil, NEVER_SAID
|
|
}
|
|
return nil, err
|
|
}
|
|
return &arc, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) incrementWordArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) {
|
|
affectedRows := 0
|
|
err := p.store.UpdateMatching(BabblerArc{},
|
|
bh.Where("fromNodeId").Eq(fromNode.NodeId).And("toNodeId").Eq(toNode.NodeId),
|
|
func(record interface{}) error {
|
|
affectedRows++
|
|
r := record.(BabblerArc)
|
|
r.Frequency += 1
|
|
return p.store.Update(r.ArcId, r)
|
|
})
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
|
|
if affectedRows == 0 {
|
|
p.store.Insert(bh.NextSequence(), BabblerArc{
|
|
FromNodeId: fromNode.NodeId,
|
|
ToNodeId: toNode.NodeId,
|
|
Frequency: 1,
|
|
})
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return p.getBabblerArc(fromNode, toNode)
|
|
}
|
|
|
|
func (p *BabblerPlugin) incrementFinalWordArcHelper(babbler *Babbler, node *BabblerNode) (*BabblerArc, error) {
|
|
nextNode, err := p.getOrCreateBabblerNode(babbler, " ")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return p.incrementWordArc(node, nextNode)
|
|
}
|
|
|
|
func (p *BabblerPlugin) addToMarkovChain(babbler *Babbler, phrase string) error {
|
|
words := strings.Fields(strings.ToLower(phrase))
|
|
|
|
if len(words) <= 0 {
|
|
return nil
|
|
}
|
|
|
|
curNode, err := p.incrementRootWordFrequency(babbler, words[0])
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return err
|
|
}
|
|
|
|
for i := 1; i < len(words); i++ {
|
|
nextNode, err := p.getOrCreateBabblerNode(babbler, words[i])
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return err
|
|
}
|
|
_, err = p.incrementWordArc(curNode, nextNode)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return err
|
|
}
|
|
curNode = nextNode
|
|
}
|
|
|
|
_, err = p.incrementFinalWordArcHelper(babbler, curNode)
|
|
return err
|
|
}
|
|
|
|
func (p *BabblerPlugin) getWeightedRootNode(babbler *Babbler) (*BabblerNode, *BabblerWord, error) {
|
|
rootNodes := []*BabblerNode{}
|
|
err := p.store.Find(&rootNodes, bh.Where("babblerId").Eq(babbler.BabblerId).And("root").Eq(1))
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, err
|
|
}
|
|
|
|
total := int64(0)
|
|
|
|
for _, n := range rootNodes {
|
|
total += n.RootFrequency
|
|
}
|
|
|
|
if len(rootNodes) == 0 {
|
|
return nil, nil, SAID_NOTHING
|
|
}
|
|
|
|
which := rand.Int63n(total)
|
|
total = 0
|
|
for _, node := range rootNodes {
|
|
total += node.RootFrequency
|
|
if total >= which {
|
|
w, err := getWord(p.store, node.WordId)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, err
|
|
}
|
|
return node, w, nil
|
|
}
|
|
|
|
}
|
|
log.Fatal().Msg("failed to find weighted root word")
|
|
return nil, nil, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) getWeightedNextWord(fromNode *BabblerNode) (*BabblerNode, *BabblerWord, error) {
|
|
arcs := []BabblerArc{}
|
|
err := p.store.Find(&arcs, bh.Where("fromNodeId").Eq(fromNode.NodeId))
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, err
|
|
}
|
|
total := int64(0)
|
|
for _, a := range arcs {
|
|
total += a.Frequency
|
|
}
|
|
|
|
if len(arcs) == 0 {
|
|
return nil, nil, errors.New("missing arcs")
|
|
}
|
|
|
|
which := rand.Int63n(total)
|
|
total = 0
|
|
for _, arc := range arcs {
|
|
|
|
total += arc.Frequency
|
|
|
|
if total >= which {
|
|
node, err := getNode(p.store, arc.ToNodeId)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, err
|
|
}
|
|
|
|
w, err := getWord(p.store, node.WordId)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, err
|
|
}
|
|
return node, w, nil
|
|
}
|
|
|
|
}
|
|
log.Fatal().Msg("failed to find weighted next word")
|
|
return nil, nil, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) getWeightedPreviousWord(toNode *BabblerNode) (*BabblerNode, *BabblerWord, bool, error) {
|
|
arcs := []*BabblerArc{}
|
|
err := p.store.Find(&arcs, bh.Where("toNodeId").Eq(toNode.NodeId))
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, false, err
|
|
}
|
|
|
|
total := int64(0)
|
|
for _, arc := range arcs {
|
|
total += arc.Frequency
|
|
}
|
|
|
|
if len(arcs) == 0 {
|
|
return nil, nil, true, nil
|
|
}
|
|
|
|
which := rand.Int63n(total + toNode.RootFrequency)
|
|
|
|
//terminate the babble
|
|
if which >= total {
|
|
return nil, nil, true, nil
|
|
}
|
|
|
|
total = 0
|
|
for _, arc := range arcs {
|
|
total += arc.Frequency
|
|
|
|
if total >= which {
|
|
node, err := getNode(p.store, arc.FromNodeId)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, false, err
|
|
}
|
|
|
|
w, err := getWord(p.store, node.WordId)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, false, err
|
|
}
|
|
return node, w, false, nil
|
|
}
|
|
}
|
|
log.Fatal().Msg("failed to find weighted previous word")
|
|
return nil, nil, false, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) verifyPhrase(babbler *Babbler, phrase []string) (*BabblerNode, *BabblerNode, error) {
|
|
curNode, err := p.getBabblerNode(babbler, phrase[0])
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, err
|
|
}
|
|
firstNode := curNode
|
|
for i := 1; i < len(phrase); i++ {
|
|
nextNode, err := p.getBabblerNode(babbler, phrase[i])
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, err
|
|
}
|
|
_, err = p.getBabblerArc(curNode, nextNode)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, nil, err
|
|
}
|
|
curNode = nextNode
|
|
}
|
|
|
|
return firstNode, curNode, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) babble(who string) (string, error) {
|
|
return p.babbleSeed(who, []string{})
|
|
}
|
|
|
|
func (p *BabblerPlugin) babbleSeed(babblerName string, seed []string) (string, error) {
|
|
babbler, err := p.getBabbler(babblerName)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", nil
|
|
}
|
|
|
|
words := seed
|
|
|
|
var curNode *BabblerNode
|
|
var curWord *BabblerWord
|
|
if len(seed) == 0 {
|
|
curNode, curWord, err = p.getWeightedRootNode(babbler)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", err
|
|
}
|
|
words = append(words, curWord.Word)
|
|
} else {
|
|
_, curNode, err = p.verifyPhrase(babbler, seed)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
for {
|
|
curNode, curWord, err = p.getWeightedNextWord(curNode)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", err
|
|
}
|
|
if curWord.Word == " " {
|
|
break
|
|
}
|
|
words = append(words, curWord.Word)
|
|
|
|
if len(words) >= 250 {
|
|
break
|
|
}
|
|
}
|
|
|
|
return strings.TrimSpace(strings.Join(words, " ")), nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) mergeBabblers(intoBabbler, otherBabbler *Babbler, intoName, otherName string) error {
|
|
intoNode, err := p.getOrCreateBabblerNode(intoBabbler, "<"+intoName+">")
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return err
|
|
}
|
|
otherNode, err := p.getOrCreateBabblerNode(otherBabbler, "<"+otherName+">")
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return err
|
|
}
|
|
|
|
mapping := map[int64]*BabblerNode{}
|
|
|
|
nodes := []*BabblerNode{}
|
|
err = p.store.Find(&nodes, bh.Where("babblerId").Eq(otherBabbler.BabblerId))
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return err
|
|
}
|
|
|
|
for _, node := range nodes {
|
|
if node.NodeId == otherNode.NodeId {
|
|
node.WordId = intoNode.WordId
|
|
}
|
|
|
|
affected := 0
|
|
if node.Root > 0 {
|
|
err = p.store.UpdateMatching(BabblerNode{},
|
|
bh.Where("babblerId").Eq(intoBabbler.BabblerId).And("wordId").Eq(node.WordId),
|
|
func(record interface{}) error {
|
|
affected++
|
|
r := record.(BabblerNode)
|
|
r.RootFrequency += node.RootFrequency
|
|
r.Root = 1
|
|
return p.store.Update(r.BabblerId, r)
|
|
})
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
}
|
|
} else {
|
|
err = p.store.UpdateMatching(BabblerNode{},
|
|
bh.Where("babblerId").Eq(intoBabbler.BabblerId).And("wordId").Eq(node.WordId),
|
|
func(record interface{}) error {
|
|
affected++
|
|
r := record.(BabblerNode)
|
|
r.RootFrequency += node.RootFrequency
|
|
return p.store.Update(r.BabblerId, r)
|
|
})
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
}
|
|
}
|
|
|
|
if err != nil || affected == 0 {
|
|
node.BabblerId = intoBabbler.BabblerId
|
|
err = p.store.Insert(bh.NextSequence(), &node)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
var updatedNode BabblerNode
|
|
err = p.store.FindOne(&updatedNode,
|
|
bh.Where("babblerId").Eq(intoBabbler.BabblerId).And("wordId").Eq(node.WordId))
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return err
|
|
}
|
|
|
|
mapping[node.NodeId] = &updatedNode
|
|
}
|
|
|
|
for oldNodeId, newNode := range mapping {
|
|
arcs := []*BabblerArc{}
|
|
err = p.store.Find(&arcs, bh.Where("fromNodeId").Eq(oldNodeId))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, arc := range arcs {
|
|
_, err := p.incrementWordArc(newNode, mapping[arc.ToNodeId])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (string, error) {
|
|
babbler, err := p.getBabbler(babblerName)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", nil
|
|
}
|
|
|
|
firstNode, curNode, err := p.verifyPhrase(babbler, seed)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", err
|
|
}
|
|
|
|
words := []string{}
|
|
var curWord *BabblerWord
|
|
var shouldTerminate bool
|
|
curNode = firstNode
|
|
for {
|
|
curNode, curWord, shouldTerminate, err = p.getWeightedPreviousWord(curNode)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", err
|
|
}
|
|
|
|
if shouldTerminate {
|
|
break
|
|
}
|
|
|
|
words = append(words, curWord.Word)
|
|
|
|
if len(words) >= 250 {
|
|
break
|
|
}
|
|
}
|
|
|
|
for i := 0; i < len(words)/2; i++ {
|
|
index := len(words) - (i + 1)
|
|
words[i], words[index] = words[index], words[i]
|
|
}
|
|
|
|
words = append(words, seed...)
|
|
|
|
return strings.TrimSpace(strings.Join(words, " ")), nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) getNextArcs(babblerNodeId int64) ([]*BabblerArc, error) {
|
|
arcs := []*BabblerArc{}
|
|
err := p.store.Find(&arcs, bh.Where("fromNodeId").Eq(babblerNodeId))
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return arcs, err
|
|
}
|
|
return arcs, nil
|
|
}
|
|
|
|
func (p *BabblerPlugin) getBabblerNodeById(nodeId int64) (*BabblerNode, error) {
|
|
node, err := getNode(p.store, nodeId)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return nil, err
|
|
}
|
|
return node, nil
|
|
}
|
|
|
|
func shuffle(a []*BabblerArc) {
|
|
for i := range a {
|
|
j := rand.Intn(i + 1)
|
|
a[i], a[j] = a[j], a[i]
|
|
}
|
|
}
|
|
|
|
func (p *BabblerPlugin) babbleSeedBookends(babblerName string, start, end []string) (string, error) {
|
|
babbler, err := p.getBabbler(babblerName)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", nil
|
|
}
|
|
|
|
_, startWordNode, err := p.verifyPhrase(babbler, start)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", err
|
|
}
|
|
|
|
endWordNode, _, err := p.verifyPhrase(babbler, end)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", err
|
|
}
|
|
|
|
type searchNode struct {
|
|
babblerNodeId int64
|
|
previous *searchNode
|
|
}
|
|
|
|
open := []*searchNode{{startWordNode.NodeId, nil}}
|
|
closed := map[int64]*searchNode{startWordNode.NodeId: open[0]}
|
|
goalNodeId := int64(-1)
|
|
|
|
for i := 0; i < len(open) && i < 1000; i++ {
|
|
cur := open[i]
|
|
|
|
arcs, err := p.getNextArcs(cur.babblerNodeId)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
//add a little randomization in through child ordering
|
|
shuffle(arcs)
|
|
|
|
for _, arc := range arcs {
|
|
if _, ok := closed[arc.ToNodeId]; !ok {
|
|
child := &searchNode{arc.ToNodeId, cur}
|
|
open = append(open, child)
|
|
closed[arc.ToNodeId] = child
|
|
|
|
if arc.ToNodeId == endWordNode.NodeId {
|
|
goalNodeId = cur.babblerNodeId
|
|
//add a little randomization in through maybe searching beyond this solution?
|
|
if rand.Intn(4) == 0 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if goalNodeId == -1 {
|
|
return "", errors.New("couldn't find path")
|
|
} else if closed[goalNodeId].previous == nil {
|
|
seeds := append(start, end...)
|
|
return strings.Join(seeds, " "), nil
|
|
}
|
|
|
|
words := []string{}
|
|
|
|
curSearchNode := closed[goalNodeId]
|
|
|
|
for {
|
|
cur, err := p.getBabblerNodeById(curSearchNode.babblerNodeId)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", err
|
|
}
|
|
w, err := getWord(p.store, cur.WordId)
|
|
if err != nil {
|
|
log.Error().Err(err)
|
|
return "", err
|
|
}
|
|
words = append(words, w.Word)
|
|
|
|
curSearchNode = closed[curSearchNode.previous.babblerNodeId]
|
|
|
|
if curSearchNode.previous == nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
for i := 0; i < len(words)/2; i++ {
|
|
index := len(words) - (i + 1)
|
|
words[i], words[index] = words[index], words[i]
|
|
}
|
|
|
|
words = append(start, words...)
|
|
words = append(words, end...)
|
|
|
|
return strings.Join(words, " "), nil
|
|
}
|