// © 2013 the CatBase Authors under the WTFPL. See AUTHORS for the list of authors. package babbler import ( "errors" "fmt" bh "github.com/timshannon/bolthold" "github.com/velour/catbase/plugins/remember" "math" "math/rand" "regexp" "strings" "github.com/rs/zerolog/log" "github.com/velour/catbase/bot" "github.com/velour/catbase/bot/msg" ) var ( NO_BABBLER = errors.New("babbler not found") SAID_NOTHING = errors.New("hasn't said anything yet") NEVER_SAID = errors.New("never said that") ) type BabblerPlugin struct { Bot bot.Bot store *bh.Store WithGoRoutines bool handlers bot.HandlerTable } type Babbler struct { BabblerID uint64 `db:"BabblerID" boltholdKey:"BabblerID"` Name string `db:"Name"` } func getBabbler(store *bh.Store, id int64) (*Babbler, error) { res := &Babbler{} err := store.Get(id, res) return res, err } type BabblerWord struct { WordID uint64 `db:"WordID" boltholdKey:"WordID"` Word string `db:"Word"` } func getWord(store *bh.Store, id uint64) (*BabblerWord, error) { res := &BabblerWord{} err := store.Get(id, res) return res, err } type BabblerNode struct { NodeID uint64 `db:"NodeID" boltholdKey:"NodeID"` BabblerID uint64 `db:"BabblerID"` WordID uint64 `db:"WordID"` Root uint64 `db:"Root"` RootFrequency int64 `db:"RootFrequency"` } func getNode(store *bh.Store, id uint64) (*BabblerNode, error) { res := &BabblerNode{} all := []BabblerNode{} err := store.Find(&all, &bh.Query{}) if err != nil { log.Error().Err(err).Msg("error finding all") return nil, err } err = store.Get(id, res) if err != nil { log.Error().Err(err).Msgf("error getting %v", id) return nil, err } return res, err } type BabblerArc struct { ArcID uint64 `db:"ArcID" boltholdKey:"ArcID"` FromNodeID uint64 `db:"FromNodeID"` ToNodeID uint64 `db:"ToNodeID"` Frequency int64 `db:"Frequency"` } func getArc(store *bh.Store, id uint64) (*BabblerArc, error) { res := &BabblerArc{} err := store.Get(id, res) return res, err } func New(b bot.Bot) *BabblerPlugin { plugin := &BabblerPlugin{ Bot: b, store: b.Store(), WithGoRoutines: true, } plugin.createNewWord("") plugin.register() return plugin } func (p *BabblerPlugin) register() { p.handlers = bot.HandlerTable{ bot.HandlerSpec{Kind: bot.Message, IsCmd: false, Regex: regexp.MustCompile(`(?i)^(?P\S+) says-bridge (?P.+)\|(?P.+)$`), Handler: func(r bot.Request) bool { who := r.Values["who"] start := strings.Fields(strings.ToLower(r.Values["start"])) end := strings.Fields(strings.ToLower(r.Values["end"])) return p.sayIt(r, p.getBabbleWithBookends(who, start, end)) }}, bot.HandlerSpec{Kind: bot.Message, IsCmd: false, Regex: regexp.MustCompile(`(?i)^(?P\S+) says-tail (?P.*)$`), Handler: func(r bot.Request) bool { who := r.Values["who"] what := strings.Fields(strings.ToLower(r.Values["what"])) return p.sayIt(r, p.getBabbleWithSuffix(who, what)) }}, bot.HandlerSpec{Kind: bot.Message, IsCmd: false, Regex: regexp.MustCompile(`(?i)^(?P\S+) says-middle-out (?P.*)$`), Handler: func(r bot.Request) bool { who := r.Values["who"] what := strings.ToLower(r.Values["what"]) tokens := strings.Fields(what) saidSomething := false saidWhat := "" saidWhatStart := p.getBabbleWithSuffix(who, tokens) saidSomethingStart := saidWhatStart != "" neverSaidLooksLike := fmt.Sprintf("%s never said", who) if !saidSomethingStart || strings.HasPrefix(saidWhatStart, neverSaidLooksLike) { saidSomething = saidSomethingStart saidWhat = saidWhatStart } else { saidWhatEnd := p.getBabble(who, tokens) saidSomethingEnd := saidWhatEnd != "" saidSomething = saidSomethingStart && saidSomethingEnd if saidSomething { saidWhat = saidWhatStart + strings.TrimPrefix(saidWhatEnd, what) } } return p.sayIt(r, saidWhat) }}, bot.HandlerSpec{Kind: bot.Message, IsCmd: false, Regex: regexp.MustCompile(`(?i)^(?P\S+) (says (?P.*)?|says)$`), Handler: func(r bot.Request) bool { who := r.Values["who"] what := strings.Fields(strings.ToLower(r.Values["what"])) return p.sayIt(r, p.getBabble(who, what)) }}, bot.HandlerSpec{Kind: bot.Message, IsCmd: false, Regex: regexp.MustCompile(`(?i)^initialize babbler for (?P\S+)$`), Handler: func(r bot.Request) bool { who := r.Values["who"] return p.sayIt(r, p.initializeBabbler(who)) }}, bot.HandlerSpec{Kind: bot.Message, IsCmd: false, Regex: regexp.MustCompile(`(?i)^merge babbler (?P\S+) into (?P\S+)$`), Handler: func(r bot.Request) bool { from, to := r.Values["from"], r.Values["to"] return p.sayIt(r, p.merge(from, to)) }}, bot.HandlerSpec{Kind: bot.Message, IsCmd: false, Regex: regexp.MustCompile(`.*`), Handler: func(r bot.Request) bool { p.addToBabbler(r.Msg.User.Name, strings.ToLower(r.Msg.Body)) return false }}, } p.Bot.RegisterTable(p, p.handlers) p.Bot.Register(p, bot.Help, p.help) } func (p *BabblerPlugin) sayIt(r bot.Request, what string) bool { if what != "" { p.Bot.Send(r.Conn, bot.Message, r.Msg.Channel, what) } return what != "" } func (p *BabblerPlugin) help(c bot.Connector, kind bot.Kind, msg msg.Message, args ...interface{}) bool { commands := []string{ "initialize babbler for seabass", "merge babbler drseabass into seabass", "seabass says ...", "seabass says-tail ...", "seabass says-middle-out ...", "seabass says-bridge ... | ...", } p.Bot.Send(c, bot.Message, msg.Channel, strings.Join(commands, "\n\n")) return true } func (p *BabblerPlugin) makeBabbler(name string) (*Babbler, error) { b := &Babbler{ Name: name, } err := p.store.Insert(bh.NextSequence(), b) if err != nil { log.Error().Err(err) return nil, err } return b, err } func (p *BabblerPlugin) getBabbler(name string) (*Babbler, error) { var bblr Babbler err := p.store.FindOne(&bblr, bh.Where("Name").Eq(name)) if err != nil { if err == bh.ErrNotFound { log.Error().Msgf("failed to find babbler for %s", name) return nil, NO_BABBLER } log.Error().Err(err).Msg("encountered problem in babbler lookup for %s") return nil, err } return &bblr, nil } func (p *BabblerPlugin) getOrCreateBabbler(name string) (*Babbler, error) { babbler, err := p.getBabbler(name) if err == NO_BABBLER { babbler, err = p.makeBabbler(name) if err != nil { log.Error().Err(err).Msg("error making babbler") return nil, err } quotes := remember.AllQuotesFrom(p.store, babbler.Name) for _, q := range quotes { if err = p.addToMarkovChain(babbler, q.Tidbit); err != nil { log.Error().Err(err).Msg("error adding to chain") } } } return babbler, err } func (p *BabblerPlugin) getWord(word string) (*BabblerWord, error) { var w BabblerWord err := p.store.FindOne(&w, bh.Where("Word").Eq(word).Limit(1)) if err != nil { if err == bh.ErrNotFound { return nil, NEVER_SAID } return nil, err } return &w, nil } func (p *BabblerPlugin) createNewWord(word string) (*BabblerWord, error) { w := &BabblerWord{Word: word} err := p.store.Insert(bh.NextSequence(), w) if err != nil { log.Error().Err(err) return nil, err } return w, nil } func (p *BabblerPlugin) getOrCreateWord(word string) (*BabblerWord, error) { if w, err := p.getWord(word); err == NEVER_SAID { return p.createNewWord(word) } else { if err != nil { log.Error().Err(err) } return w, err } } func (p *BabblerPlugin) getBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) { w, err := p.getWord(word) if err != nil { return nil, err } var node BabblerNode err = p.store.FindOne(&node, bh.Where("BabblerID").Eq(babbler.BabblerID).And("WordID").Eq(w.WordID)) if err != nil { if err == bh.ErrNotFound { return nil, NEVER_SAID } return nil, err } return &node, nil } func (p *BabblerPlugin) createBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) { w, err := p.getOrCreateWord(word) if err != nil { log.Error().Err(err) return nil, err } bn := &BabblerNode{ BabblerID: babbler.BabblerID, WordID: w.WordID, Root: 0, RootFrequency: 0, } err = p.store.Insert(bh.NextSequence(), bn) if err != nil { log.Error().Err(err) return nil, err } return bn, nil } func (p *BabblerPlugin) getOrCreateBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) { node, err := p.getBabblerNode(babbler, word) if err != nil { return p.createBabblerNode(babbler, word) } return node, nil } func (p *BabblerPlugin) incrementRootWordFrequency(babbler *Babbler, word string) (*BabblerNode, error) { node, err := p.getOrCreateBabblerNode(babbler, word) if err != nil { log.Error().Err(err).Msg("error getOrCreateBabblerNode") return nil, err } node.RootFrequency += 1 node.Root = 1 err = p.store.Update(node.NodeID, node) if err != nil { log.Error().Err(err).Msg("error updating matching") return nil, err } node.RootFrequency += 1 return node, nil } func (p *BabblerPlugin) getBabblerArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) { var arc BabblerArc err := p.store.FindOne(&arc, bh.Where("FromNodeID").Eq(fromNode.NodeID).And("ToNodeID").Eq(toNode.NodeID)) if err != nil { if err == bh.ErrNotFound { return nil, NEVER_SAID } return nil, err } return &arc, nil } func (p *BabblerPlugin) incrementWordArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) { affectedRows := 0 err := p.store.UpdateMatching(BabblerArc{}, bh.Where("FromNodeID").Eq(fromNode.NodeID).And("ToNodeID").Eq(toNode.NodeID), func(record interface{}) error { affectedRows++ r, ok := record.(*BabblerArc) if !ok { return fmt.Errorf("incorrect type: expected BabblerArc, got %T", record) } r.Frequency += 1 return nil }) if err != nil { log.Error().Err(err).Msg("error updating arcs") return nil, err } if affectedRows == 0 { err = p.store.Insert(bh.NextSequence(), BabblerArc{ FromNodeID: fromNode.NodeID, ToNodeID: toNode.NodeID, Frequency: 1, }) if err != nil { log.Error().Err(err).Msg("error inserting arc") return nil, err } } return p.getBabblerArc(fromNode, toNode) } func (p *BabblerPlugin) incrementFinalWordArcHelper(babbler *Babbler, node *BabblerNode) (*BabblerArc, error) { nextNode, err := p.getOrCreateBabblerNode(babbler, " ") if err != nil { return nil, err } return p.incrementWordArc(node, nextNode) } func (p *BabblerPlugin) addToMarkovChain(babbler *Babbler, phrase string) error { words := strings.Fields(strings.ToLower(phrase)) if len(words) <= 0 { return nil } curNode, err := p.incrementRootWordFrequency(babbler, words[0]) if err != nil { log.Error().Err(err).Msg("incrementRootWordFrequency") return err } for i := 1; i < len(words); i++ { nextNode, err := p.getOrCreateBabblerNode(babbler, words[i]) if err != nil { log.Error().Err(err).Msg("getOrCreateBabblerNode") return err } _, err = p.incrementWordArc(curNode, nextNode) if err != nil { log.Error().Err(err).Msg("incrementWordArc") return err } curNode = nextNode } _, err = p.incrementFinalWordArcHelper(babbler, curNode) return err } func (p *BabblerPlugin) getWeightedRootNode(babbler *Babbler) (*BabblerNode, *BabblerWord, error) { rootNodes := []*BabblerNode{} err := p.store.Find(&rootNodes, bh.Where("BabblerID").Eq(babbler.BabblerID).And("Root").Eq(uint64(1))) if err != nil { log.Error().Err(err) return nil, nil, err } total := int64(0) for _, n := range rootNodes { total += n.RootFrequency } if len(rootNodes) == 0 { return nil, nil, SAID_NOTHING } which := rand.Int63n(total) total = 0 for _, node := range rootNodes { total += node.RootFrequency if total >= which { w, err := getWord(p.store, node.WordID) if err != nil { log.Error().Err(err) return nil, nil, err } return node, w, nil } } log.Fatal().Msg("failed to find weighted root word") return nil, nil, nil } func (p *BabblerPlugin) getWeightedNextWord(fromNode *BabblerNode) (*BabblerNode, *BabblerWord, error) { arcs := []BabblerArc{} err := p.store.Find(&arcs, bh.Where("FromNodeID").Eq(fromNode.NodeID)) if err != nil { log.Error().Err(err) return nil, nil, err } total := int64(0) for _, a := range arcs { total += a.Frequency } if len(arcs) == 0 { return nil, nil, errors.New("missing arcs") } which := rand.Int63n(total) total = 0 for _, arc := range arcs { total += arc.Frequency if total >= which { node, err := getNode(p.store, arc.ToNodeID) if err != nil { log.Error().Err(err).Msg("getNode") return nil, nil, err } w, err := getWord(p.store, node.WordID) if err != nil { log.Error().Err(err).Msg("getWord") return nil, nil, err } return node, w, nil } } log.Fatal().Msg("failed to find weighted next word") return nil, nil, nil } func (p *BabblerPlugin) getWeightedPreviousWord(toNode *BabblerNode) (*BabblerNode, *BabblerWord, bool, error) { arcs := []*BabblerArc{} err := p.store.Find(&arcs, bh.Where("ToNodeID").Eq(toNode.NodeID)) if err != nil { log.Error().Err(err) return nil, nil, false, err } total := int64(0) for _, arc := range arcs { total += arc.Frequency } if len(arcs) == 0 { return nil, nil, true, nil } which := rand.Int63n(total + toNode.RootFrequency) //terminate the babble if which >= total { return nil, nil, true, nil } total = 0 for _, arc := range arcs { total += arc.Frequency if total >= which { node, err := getNode(p.store, arc.FromNodeID) if err != nil { log.Error().Err(err) return nil, nil, false, err } w, err := getWord(p.store, node.WordID) if err != nil { log.Error().Err(err) return nil, nil, false, err } return node, w, false, nil } } log.Fatal().Msg("failed to find weighted previous word") return nil, nil, false, nil } func (p *BabblerPlugin) verifyPhrase(babbler *Babbler, phrase []string) (*BabblerNode, *BabblerNode, error) { curNode, err := p.getBabblerNode(babbler, phrase[0]) if err != nil { log.Error().Err(err) return nil, nil, err } firstNode := curNode for i := 1; i < len(phrase); i++ { nextNode, err := p.getBabblerNode(babbler, phrase[i]) if err != nil { log.Error().Err(err) return nil, nil, err } _, err = p.getBabblerArc(curNode, nextNode) if err != nil { log.Error().Err(err) return nil, nil, err } curNode = nextNode } return firstNode, curNode, nil } func (p *BabblerPlugin) babble(who string) (string, error) { return p.babbleSeed(who, []string{}) } func (p *BabblerPlugin) babbleSeed(babblerName string, seed []string) (string, error) { babbler, err := p.getBabbler(babblerName) if err != nil { log.Error().Err(err).Msg("error getting babbler") return "", nil } words := seed var curNode *BabblerNode var curWord *BabblerWord if len(seed) == 0 { curNode, curWord, err = p.getWeightedRootNode(babbler) if err != nil { log.Error().Err(err).Msg("error getWeightedRootNode") return "", err } words = append(words, curWord.Word) } else { _, curNode, err = p.verifyPhrase(babbler, seed) if err != nil { log.Error().Err(err).Msg("error verifyPhrase") return "", err } } for { curNode, curWord, err = p.getWeightedNextWord(curNode) if err != nil { log.Error().Err(err).Msg("err getWeightedNextWord") return "", err } if curWord.Word == " " { break } words = append(words, curWord.Word) if len(words) >= 250 { break } } return strings.TrimSpace(strings.Join(words, " ")), nil } func (p *BabblerPlugin) mergeBabblers(intoBabbler, otherBabbler *Babbler, intoName, otherName string) error { intoNode, err := p.getOrCreateBabblerNode(intoBabbler, "<"+intoName+">") if err != nil { log.Error().Err(err) return err } otherNode, err := p.getOrCreateBabblerNode(otherBabbler, "<"+otherName+">") if err != nil { log.Error().Err(err) return err } mapping := map[uint64]*BabblerNode{} nodes := []*BabblerNode{} err = p.store.Find(&nodes, bh.Where("BabblerID").Eq(otherBabbler.BabblerID)) if err != nil { log.Error().Err(err) return err } for _, node := range nodes { if node.NodeID == otherNode.NodeID { node.WordID = intoNode.WordID } affected := 0 if node.Root > 0 { err = p.store.UpdateMatching(BabblerNode{}, bh.Where("BabblerID").Eq(intoBabbler.BabblerID).And("WordID").Eq(node.WordID), func(record interface{}) error { r, ok := record.(*BabblerNode) if !ok { return fmt.Errorf("expected BabblerNode, got %T", record) } affected++ r.RootFrequency += node.RootFrequency r.Root = 1 return nil }) if err != nil { log.Error().Err(err) } } else { err = p.store.UpdateMatching(BabblerNode{}, bh.Where("BabblerID").Eq(intoBabbler.BabblerID).And("WordID").Eq(node.WordID), func(record interface{}) error { r, ok := record.(*BabblerNode) if !ok { return fmt.Errorf("expected BabblerNode, got %T", record) } affected++ r.RootFrequency += node.RootFrequency return nil }) if err != nil { log.Error().Err(err) } } if err != nil || affected == 0 { node.BabblerID = intoBabbler.BabblerID err = p.store.Insert(bh.NextSequence(), node) if err != nil { log.Error().Err(err).Msg("error inserting node") return err } log.Debug().Msgf("Inserted: %+v", node) } else { log.Error().Err(err).Int("affected", affected).Msgf("problem before insert") } var updatedNode BabblerNode err = p.store.FindOne(&updatedNode, bh.Where("BabblerID").Eq(intoBabbler.BabblerID).And("WordID").Eq(node.WordID)) if err != nil { log.Error().Err(err).Msg("error finding updated node and also why do we need this?") return err } mapping[node.NodeID] = &updatedNode } for oldNodeId, newNode := range mapping { arcs := []*BabblerArc{} err = p.store.Find(&arcs, bh.Where("FromNodeID").Eq(oldNodeId)) if err != nil { return err } for _, arc := range arcs { _, err := p.incrementWordArc(newNode, mapping[arc.ToNodeID]) if err != nil { return err } } } return err } func (p *BabblerPlugin) babbleSeedSuffix(babblerName string, seed []string) (string, error) { babbler, err := p.getBabbler(babblerName) if err != nil { log.Error().Err(err) return "", nil } firstNode, curNode, err := p.verifyPhrase(babbler, seed) if err != nil { log.Error().Err(err) return "", err } words := []string{} var curWord *BabblerWord var shouldTerminate bool curNode = firstNode for { curNode, curWord, shouldTerminate, err = p.getWeightedPreviousWord(curNode) if err != nil { log.Error().Err(err) return "", err } if shouldTerminate { break } words = append(words, curWord.Word) if len(words) >= 250 { break } } for i := 0; i < len(words)/2; i++ { index := len(words) - (i + 1) words[i], words[index] = words[index], words[i] } words = append(words, seed...) return strings.TrimSpace(strings.Join(words, " ")), nil } func (p *BabblerPlugin) getNextArcs(babblerNodeId uint64) ([]*BabblerArc, error) { arcs := []*BabblerArc{} err := p.store.Find(&arcs, bh.Where("FromNodeID").Eq(babblerNodeId)) if err != nil { log.Error().Err(err) return arcs, err } return arcs, nil } func (p *BabblerPlugin) getBabblerNodeById(nodeId uint64) (*BabblerNode, error) { node, err := getNode(p.store, nodeId) if err != nil { log.Error().Err(err) return nil, err } return node, nil } func shuffle(a []*BabblerArc) { for i := range a { j := rand.Intn(i + 1) a[i], a[j] = a[j], a[i] } } func (p *BabblerPlugin) babbleSeedBookends(babblerName string, start, end []string) (string, error) { babbler, err := p.getBabbler(babblerName) if err != nil { log.Error().Err(err) return "", nil } _, startWordNode, err := p.verifyPhrase(babbler, start) if err != nil { log.Error().Err(err) return "", err } endWordNode, _, err := p.verifyPhrase(babbler, end) if err != nil { log.Error().Err(err) return "", err } type searchNode struct { babblerNodeId uint64 previous *searchNode } open := []*searchNode{{startWordNode.NodeID, nil}} closed := map[uint64]*searchNode{startWordNode.NodeID: open[0]} goalNodeId := uint64(math.MaxUint64) for i := 0; i < len(open) && i < 1000; i++ { cur := open[i] arcs, err := p.getNextArcs(cur.babblerNodeId) if err != nil { return "", err } //add a little randomization in through child ordering shuffle(arcs) for _, arc := range arcs { if _, ok := closed[arc.ToNodeID]; !ok { child := &searchNode{arc.ToNodeID, cur} open = append(open, child) closed[arc.ToNodeID] = child if arc.ToNodeID == endWordNode.NodeID { goalNodeId = cur.babblerNodeId //add a little randomization in through maybe searching beyond this solution? if rand.Intn(4) == 0 { break } } } } } if goalNodeId == math.MaxUint64 { return "", errors.New("couldn't find path") } else if closed[goalNodeId].previous == nil { seeds := append(start, end...) return strings.Join(seeds, " "), nil } words := []string{} curSearchNode := closed[goalNodeId] for { cur, err := p.getBabblerNodeById(curSearchNode.babblerNodeId) if err != nil { log.Error().Err(err) return "", err } w, err := getWord(p.store, cur.WordID) if err != nil { log.Error().Err(err) return "", err } words = append(words, w.Word) curSearchNode = closed[curSearchNode.previous.babblerNodeId] if curSearchNode.previous == nil { break } } for i := 0; i < len(words)/2; i++ { index := len(words) - (i + 1) words[i], words[index] = words[index], words[i] } words = append(start, words...) words = append(words, end...) return strings.Join(words, " "), nil }