diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 00574c0..0c54b45 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -3,6 +3,7 @@ package babbler import ( + "database/sql" "errors" "fmt" "log" @@ -15,12 +16,43 @@ import ( "github.com/velour/catbase/config" ) +var ( + NO_BABBLER = errors.New("babbler not found") + SAID_NOTHING = errors.New("hasn't said anything yet") + NEVER_SAID = errors.New("never said that") +) + type BabblerPlugin struct { Bot bot.Bot db *sqlx.DB config *config.Config } +type Babbler struct { + BabblerId int64 `db:"id"` + Name string `db:"babbler"` +} + +type BabblerWord struct { + WordId int64 `db:"id"` + Word string `db:"word"` +} + +type BabblerNode struct { + NodeId int64 `db:"id"` + BabblerId int64 `db:"babblerId"` + WordId int64 `db:"wordId"` + Root int64 `db:"root"` + RootFrequency int64 `db:"rootFrequency"` +} + +type BabblerArc struct { + ArcId int64 `db:"id"` + FromNodeId int64 `db:"fromNodeId"` + ToNodeId int64 `db:"toNodeId"` + Frequency int64 `db:"frequency"` +} + func New(bot bot.Bot) *BabblerPlugin { log.SetFlags(log.LstdFlags | log.Lshortfile) @@ -31,10 +63,18 @@ func New(bot bot.Bot) *BabblerPlugin { );`); err != nil { log.Fatal(err) } + if _, err := bot.DB().Exec(`create table if not exists babblerWords ( + id integer primary key, + word string + );`); err != nil { + log.Fatal(err) + } + + if _, err := bot.DB().Exec(`create table if not exists babblerNodes ( id integer primary key, babblerId integer, - word string, + wordId integer, root integer, rootFrequency integer );`); err != nil { @@ -43,8 +83,8 @@ func New(bot bot.Bot) *BabblerPlugin { if _, err := bot.DB().Exec(`create table if not exists babblerArcs ( id integer primary key, - fromWordId integer, - toWordId interger, + fromNodeId integer, + toNodeId interger, frequency integer );`); err != nil { log.Fatal(err) @@ -57,6 +97,8 @@ func New(bot bot.Bot) *BabblerPlugin { config: bot.Config(), } + plugin.createNewWord("") + return plugin } @@ -65,97 +107,26 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { tokens := strings.Fields(lowercase) numTokens := len(tokens) + saidSomething := false + saidWhat := "" + if numTokens >= 2 && tokens[1] == "says" { - who := tokens[0] - _, err := p.getBabbler(who) - - if err != nil { - return false - } - - var saying string - if len(tokens) == 2 { - saying, _ = p.babble(who) - } else { - saying, _ = p.babbleSeed(who, tokens[2:]) - } - - if saying != "" { - p.Bot.SendMessage(message.Channel, saying) - return true - } + saidWhat, saidSomething = p.getBabble(tokens) } else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 { - who := tokens[3] - _, err := p.getOrCreateBabbler(who) - if err != nil { - p.Bot.SendMessage(message.Channel, "babbler initialization failed.") - return true - } - p.Bot.SendMessage(message.Channel, "Okay.") - return true + saidWhat, saidSomething = p.initializeBabbler(tokens) } else if strings.Index(lowercase, "batch learn for ") == 0 { - who := tokens[3] - babblerId, err := p.getOrCreateBabbler(who) - if err != nil { - p.Bot.SendMessage(message.Channel, "batch learn failed.") - return true - } - - body := strings.Join(tokens[4:], " ") - body = strings.ToLower(body) - - for _, a := range strings.Split(body, ".") { - for _, b := range strings.Split(a, "!") { - for _, c := range strings.Split(b, "?") { - for _, d := range strings.Split(c, "\n") { - trimmed := strings.TrimSpace(d) - if trimmed != "" { - p.addToMarkovChain(babblerId, trimmed) - } - } - } - } - } - - p.Bot.SendMessage(message.Channel, "phew that was tiring.") - return true + saidWhat, saidSomething = p.batchLearn(tokens) } else if len(tokens) == 5 && strings.Index(lowercase, "merge babbler") == 0 { - if tokens[3] != "into" { - p.Bot.SendMessage(message.Channel, "try using 'merge babbler [x] into [y]'") - return true - } - - who := tokens[2] - into := tokens[4] - - if who == into { - p.Bot.SendMessage(message.Channel, "that's annoying. stop it.") - return true - } - - whoBabbler, err := p.getBabbler(who) - if err != nil { - p.Bot.SendMessage(message.Channel, "merge failed.") - return true - } - intoBabbler, err := p.getOrCreateBabbler(into) - if err != nil { - p.Bot.SendMessage(message.Channel, "merge failed.") - return true - } - - p.merge(intoBabbler, whoBabbler, into, who) - - p.Bot.SendMessage(message.Channel, "mooooiggged") - return true + saidWhat, saidSomething = p.merge(tokens) } else { - babblerId, err := p.getOrCreateBabbler(message.User.Name) - if err == nil { - p.addToMarkovChain(babblerId, lowercase) - } + //this should always return "", false + saidWhat, saidSomething = p.addToBabbler(message.User.Name, lowercase) } - return false + if saidSomething { + p.Bot.SendMessage(message.Channel, saidWhat) + } + return saidSomething } func (p *BabblerPlugin) Help(channel string, parts []string) { @@ -174,391 +145,513 @@ func (p *BabblerPlugin) RegisterWeb() *string { return nil } -func (p *BabblerPlugin) makeBabbler(babbler string) (int64, error) { - res, err := p.db.Exec(`insert into babblers (babbler) values (?);`, babbler) +func (p *BabblerPlugin) makeBabbler(name string) (*Babbler, error) { + res, err := p.db.Exec(`insert into babblers (babbler) values (?);`, name) if err == nil { - id, _ := res.LastInsertId() - return id, nil + id, err := res.LastInsertId() + if err != nil { + log.Print(err) + return nil, err + } + return &Babbler{ + BabblerId: id, + Name: name, + }, nil } - return -1, err + return nil, err } -func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) { - id := int64(-1) - err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler) +func (p *BabblerPlugin) getBabbler(name string) (*Babbler, error) { + var bblr Babbler + err := p.db.QueryRowx(`select * from babblers where babbler = ? LIMIT 1;`, name).StructScan(&bblr) if err != nil { + if err == sql.ErrNoRows { + log.Printf("failed to find babbler") + return nil, NO_BABBLER + } + log.Printf("encountered problem in babbler lookup") + log.Print(err) + return nil, err } - return id, err + return &bblr, nil } -func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) { - id, err := p.getBabbler(babbler) - if err != nil { - id, err = p.makeBabbler(babbler) +func (p *BabblerPlugin) getOrCreateBabbler(name string) (*Babbler, error) { + babbler, err := p.getBabbler(name) + if err == NO_BABBLER { + babbler, err = p.makeBabbler(name) if err != nil { - return id, err + log.Print(err) + return nil, err } - query := fmt.Sprintf(`select tidbit from factoid where fact like '%s quotes';`, babbler) - rows, err := p.db.Query(query) + + rows, err := p.db.Queryx(fmt.Sprintf("select tidbit from factoid where fact like '%s quotes';", babbler.Name)) if err != nil { - //we'll just ignore this but the actual creation succeeded previously - return id, nil + log.Print(err) + return babbler, nil } + defer rows.Close() + + tidbits := []string{} for rows.Next() { var tidbit string err := rows.Scan(&tidbit) + + log.Print(tidbit) + if err != nil { - return id, err + log.Print(err) + return babbler, err } - p.addToMarkovChain(id, tidbit) + tidbits = append(tidbits, tidbit) } + for _, tidbit := range tidbits { + if err = p.addToMarkovChain(babbler, tidbit); err != nil { + log.Print(err) + } + } } - return id, err + return babbler, err } -func (p *BabblerPlugin) getWordId(babblerId int64, word string) (int64, error) { - id := int64(-1) - err := p.db.Get(&id, `select id from babblerWords where babblerId = ? and word = ?`, babblerId, word) - return id, err -} - -func (p *BabblerPlugin) createNewWord(babblerId int64, word string) (int64, error) { - res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?, ?, 0, 0);`, babblerId, word) +func (p *BabblerPlugin) getWord(word string) (*BabblerWord, error) { + var w BabblerWord + err := p.db.QueryRowx(`select * from babblerWords where word = ? LIMIT 1;`, word).StructScan(&w) if err != nil { - return -1, err + if err == sql.ErrNoRows { + return nil, NEVER_SAID + } + return nil, err } - id, _ := res.LastInsertId() - return id, nil + return &w, nil } -func (p *BabblerPlugin) getOrCreateWord(babblerId int64, word string) (int64, error) { - id, err := p.getWordId(babblerId, word) +func (p *BabblerPlugin) createNewWord(word string) (*BabblerWord, error) { + res, err := p.db.Exec(`insert into babblerWords (word) values (?);`, word) if err != nil { - return p.createNewWord(babblerId, word) + log.Print(err) + return nil, err } - return id, err -} - -func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) { - id, err := p.getOrCreateWord(babblerId, word) + id, err := res.LastInsertId() if err != nil { - return -1, err + log.Print(err) + return nil, err } - _, err = p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, id) + return &BabblerWord{ + WordId: id, + Word: word, + }, nil +} + +func (p *BabblerPlugin) getOrCreateWord(word string) (*BabblerWord, error) { + if w, err := p.getWord(word); err == NEVER_SAID { + return p.createNewWord(word) + } else { + if err != nil { + log.Print(err) + } + return w, err + } +} + +func (p *BabblerPlugin) getBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) { + w, err := p.getWord(word) if err != nil { - return -1, err + return nil, err } - return id, nil + + var node BabblerNode + err = p.db.QueryRowx(`select * from babblerNodes where babblerId = ? and wordId = ? LIMIT 1;`, babbler.BabblerId, w.WordId).StructScan(&node) + if err != nil { + if err == sql.ErrNoRows { + return nil, NEVER_SAID + } + return nil, err + } + return &node, nil } -func (p *BabblerPlugin) getWordArcHelper(fromWordId, toWordId int64) (int64, error) { - id := int64(-1) - err := p.db.Get(&id, `select id from babblerArcs where fromWordId = ? and toWordId = ?`, fromWordId, toWordId) - return id, err +func (p *BabblerPlugin) createBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) { + w, err := p.getOrCreateWord(word) + if err != nil { + log.Print(err) + return nil, err + } + + res, err := p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?, ?, 0, 0)`, babbler.BabblerId, w.WordId) + if err != nil { + log.Print(err) + return nil, err + } + + id, err := res.LastInsertId() + if err != nil { + log.Print(err) + return nil, err + } + + return &BabblerNode{ + NodeId: id, + WordId: w.WordId, + Root: 0, + RootFrequency: 0, + }, nil } -func (p *BabblerPlugin) incrementWordArc(fromWordId, toWordId int64) (int64, error) { - res, err := p.db.Exec(`update babblerArcs set frequency = frequency + 1 where fromWordId = ? and toWordId = ?`, fromWordId, toWordId) +func (p *BabblerPlugin) getOrCreateBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) { + node, err := p.getBabblerNode(babbler, word) + if err != nil { + return p.createBabblerNode(babbler, word) + } + return node, nil +} + +func (p *BabblerPlugin) incrementRootWordFrequency(babbler *Babbler, word string) (*BabblerNode, error) { + node, err := p.getOrCreateBabblerNode(babbler, word) + if err != nil { + log.Print(err) + return nil, err + } + _, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, node.NodeId) + if err != nil { + log.Print(err) + return nil, err + } + node.RootFrequency += 1 + return node, nil +} + +func (p *BabblerPlugin) getBabblerArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) { + var arc BabblerArc + err := p.db.QueryRowx(`select * from babblerArcs where fromNodeId = ? and toNodeId = ?;`, fromNode.NodeId, toNode.NodeId).StructScan(&arc) + if err != nil { + if err == sql.ErrNoRows { + return nil, NEVER_SAID + } + return nil, err + } + return &arc, nil +} + +func (p *BabblerPlugin) incrementWordArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) { + res, err := p.db.Exec(`update babblerArcs set frequency = frequency + 1 where fromNodeId = ? and toNodeId = ?;`, fromNode.NodeId, toNode.NodeId) + if err != nil { + log.Print(err) + return nil, err + } + affectedRows := int64(0) if err == nil { affectedRows, _ = res.RowsAffected() } - if err != nil || affectedRows == 0 { - res, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1);`, fromWordId, toWordId) + if affectedRows == 0 { + res, err = p.db.Exec(`insert into babblerArcs (fromNodeId, toNodeId, frequency) values (?, ?, 1);`, fromNode.NodeId, toNode.NodeId) if err != nil { - - return -1, err + log.Print(err) + return nil, err } } - id, _ := res.LastInsertId() - return id, nil + return p.getBabblerArc(fromNode, toNode) } -func (p *BabblerPlugin) incrementFinalWordArcHelper(wordId int64) (int64, error) { - return p.incrementWordArc(wordId, -1) -} - -func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWord string) (int64, error) { - toWordId, err := p.getOrCreateWord(babblerId, toWord) +func (p *BabblerPlugin) incrementFinalWordArcHelper(babbler *Babbler, node *BabblerNode) (*BabblerArc, error) { + nextNode, err := p.getOrCreateBabblerNode(babbler, " ") if err != nil { - return -1, err + return nil, err } - _, err = p.incrementWordArc(fromWordId, toWordId) - if err != nil { - return -1, err - } - return toWordId, nil + return p.incrementWordArc(node, nextNode) } -func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { +func (p *BabblerPlugin) addToMarkovChain(babbler *Babbler, phrase string) error { words := strings.Fields(strings.ToLower(phrase)) - id, err := p.incrementRootWordFrequency(babblerId, words[0]) + if len(words) <= 0 { + return nil + } + + curNode, err := p.incrementRootWordFrequency(babbler, words[0]) if err != nil { - return + log.Print(err) + return err } for i := 1; i < len(words); i++ { - id, err = p.incrementWordArcHelper(babblerId, id, words[i]) + nextNode, err := p.getOrCreateBabblerNode(babbler, words[i]) if err != nil { - return + log.Print(err) + return err } + _, err = p.incrementWordArc(curNode, nextNode) + if err != nil { + log.Print(err) + return err + } + curNode = nextNode } - _, err = p.incrementFinalWordArcHelper(id) + _, err = p.incrementFinalWordArcHelper(babbler, curNode) + return err } -func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, error) { - query := fmt.Sprintf("select id, word, rootFrequency from babblerWords where babblerId = %d and root = 1", babblerId) - rows, err := p.db.Query(query) +func (p *BabblerPlugin) getWeightedRootNode(babbler *Babbler) (*BabblerNode, *BabblerWord, error) { + rows, err := p.db.Queryx(`select * from babblerNodes where babblerId = ? and root = 1;`, babbler.BabblerId) if err != nil { - return -1, "", err + log.Print(err) + return nil, nil, err } defer rows.Close() - idToWord := map[int64]string{} - idToFreq := map[int64]int64{} + rootNodes := []*BabblerNode{} total := int64(0) for rows.Next() { - var id int64 - var word string - var rootFrequency int64 - err = rows.Scan(&id, &word, &rootFrequency) + var node BabblerNode + err = rows.StructScan(&node) if err != nil { - return -1, "", err + log.Print(err) + return nil, nil, err } + rootNodes = append(rootNodes, &node) + total += node.RootFrequency + } - total += rootFrequency - idToFreq[id] = rootFrequency - idToWord[id] = word + if len(rootNodes) == 0 { + return nil, nil, SAID_NOTHING } which := rand.Int63n(total) total = 0 - for id, freq := range idToFreq { - if total+freq >= which { - return id, idToWord[id], nil - } - total += freq - } - log.Fatalf("shouldn't happen") - return -1, "", errors.New("failed to find weighted root word") -} - -func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, error) { - query := fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", fromWordId) - rows, err := p.db.Query(query) - if err != nil { - return -1, "", err - } - defer rows.Close() - - idToFreq := map[int64]int64{} - total := int64(0) - - for rows.Next() { - var toWordId int64 - var frequency int64 - err = rows.Scan(&toWordId, &frequency) - if err != nil { - return -1, "", err - } - total += frequency - idToFreq[toWordId] = frequency - } - - which := rand.Int63n(total) - total = 0 - for id, freq := range idToFreq { - if total+freq >= which { - if id < 0 { - return -1, "", nil - } - - var word string - err := p.db.Get(&word, `select word from babblerWords where id = ?`, id) + for _, node := range rootNodes { + total += node.RootFrequency + if total >= which { + var w BabblerWord + err := p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, node.WordId).StructScan(&w) if err != nil { - return -1, "", err + log.Print(err) + return nil, nil, err } - return id, word, nil + return node, &w, nil } - total +=freq + } log.Fatalf("shouldn't happen") - return -1, "", errors.New("failed to find weighted next word") + return nil, nil, errors.New("failed to find weighted root word") +} + +func (p *BabblerPlugin) getWeightedNextWord(fromNode *BabblerNode) (*BabblerNode, *BabblerWord, error) { + rows, err := p.db.Queryx(`select * from babblerArcs where fromNodeId = ?;`, fromNode.NodeId) + if err != nil { + log.Print(err) + return nil, nil, err + } + defer rows.Close() + + arcs := []*BabblerArc{} + total := int64(0) + for rows.Next() { + var arc BabblerArc + err = rows.StructScan(&arc) + if err != nil { + log.Print(err) + return nil, nil, err + } + arcs = append(arcs, &arc) + total += arc.Frequency + } + + if len(arcs) == 0 { + return nil, nil, errors.New("missing arcs") + } + + which := rand.Int63n(total) + total = 0 + for _, arc := range arcs { + + total += arc.Frequency + + if total >= which { + var node BabblerNode + err := p.db.QueryRowx(`select * from babblerNodes where id = ? LIMIT 1;`, arc.ToNodeId).StructScan(&node) + if err != nil { + log.Print(err) + return nil, nil, err + } + + var w BabblerWord + err = p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, node.WordId).StructScan(&w) + if err != nil { + log.Print(err) + return nil, nil, err + } + return &node, &w, nil + } + + } + log.Fatalf("shouldn't happen") + return nil, nil, errors.New("failed to find weighted next word") } func (p *BabblerPlugin) babble(who string) (string, error) { return p.babbleSeed(who, []string{}) } -func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error) { - babblerId, err := p.getBabbler(babbler) +func (p *BabblerPlugin) babbleSeed(babblerName string, seed []string) (string, error) { + babbler, err := p.getBabbler(babblerName) if err != nil { + log.Print(err) return "", nil } words := seed - var curWordId int64 + var curNode *BabblerNode + var curWord *BabblerWord if len(seed) == 0 { - id, word, err := p.getWeightedRootWord(babblerId) + curNode, curWord, err = p.getWeightedRootNode(babbler) if err != nil { + log.Print(err) return "", err } - curWordId = id - words = append(words, word) + words = append(words, curWord.Word) } else { - id, err := p.getWordId(babblerId, seed[0]) + curNode, err = p.getBabblerNode(babbler, seed[0]) if err != nil { + log.Print(err) return "", err } - curWordId = id for i := 1; i < len(seed); i++ { - nextWordId, err := p.getWordId(babblerId, seed[i]) + nextNode, err := p.getBabblerNode(babbler, seed[i]) if err != nil { + log.Print(err) return "", err } - _, err = p.getWordArcHelper(curWordId, nextWordId) + _, err = p.getBabblerArc(curNode, nextNode) if err != nil { + log.Print(err) return "", err } - curWordId = nextWordId + curNode = nextNode } } for { - id, word, err := p.getWeightedNextWord(curWordId) + curNode, curWord, err = p.getWeightedNextWord(curNode) if err != nil { + log.Print(err) return "", err } - if id < 0 { + if curWord.Word == " " { break } - words = append(words, word) - curWordId = id + words = append(words, curWord.Word) } return strings.TrimSpace(strings.Join(words, " ")), nil } -func (p *BabblerPlugin) merge(intoId, otherId int64, intoName, otherName string) error { - intoString := "<" + intoName + ">" - otherString := "<" + otherName + ">" - - mapping := map[int64]int64{} - - query := fmt.Sprintf("select id, word, root, rootFrequency from babblerWords where babblerId = %d;", otherId) - rows, err := p.db.Query(query) +func (p *BabblerPlugin) mergeBabblers(intoBabbler, otherBabbler *Babbler, intoName, otherName string) error { + intoNode, err := p.getOrCreateBabblerNode(intoBabbler, "<"+intoName+">") if err != nil { + log.Print(err) + return err + } + otherNode, err := p.getOrCreateBabblerNode(otherBabbler, "<"+otherName+">") + if err != nil { + log.Print(err) + return err + } + + mapping := map[int64]*BabblerNode{} + + rows, err := p.db.Queryx("select * from babblerNodes where babblerId = ?;", otherBabbler.BabblerId) + if err != nil { + log.Print(err) return err } defer rows.Close() - type Word struct { - Id int64 - Word string - Root int64 - RootFrequency int64 - } - - words := []Word{} + nodes := []*BabblerNode{} for rows.Next() { - word := Word{} - err = rows.Scan(&word.Id, &word.Word, &word.Root, &word.RootFrequency) + var node BabblerNode + err = rows.StructScan(&node) if err != nil { + log.Print(err) return err } - words = append(words, word) + nodes = append(nodes, &node) } - for _, word := range words { + for _, node := range nodes { + var res sql.Result - if word.Word == otherString { - word.Word = intoString + if node.NodeId == otherNode.NodeId { + node.WordId = intoNode.WordId } - doInsert := false - wordId := int64(-1) - if word.Root > 0 { - res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and word = ? output id ;`, word.RootFrequency, intoId, word.Word) - rowsAffected := int64(0) - if err == nil { - rowsAffected, _ = res.RowsAffected() - } - if err != nil || rowsAffected == 0 { - doInsert = true - } else { - wordId, _ = res.LastInsertId() + if node.Root > 0 { + res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and wordId = ?;`, node.RootFrequency, intoBabbler.BabblerId, node.WordId) + if err != nil { + log.Print(err) } } else { - res, err := p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + ? where babblerId = ? and word = ? output id;`, word.RootFrequency, intoId, word.Word) + res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ? where babblerId = ? and wordId = ?;`, node.RootFrequency, intoBabbler.BabblerId, node.WordId) if err != nil { - doInsert = true - } else { - wordId, _ = res.LastInsertId() + log.Print(err) } } - if doInsert { - res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Word, word.Root, word.RootFrequency) + rowsAffected := int64(-1) + if err == nil { + rowsAffected, _ = res.RowsAffected() + } + + if err != nil || rowsAffected == 0 { + res, err = p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?,?,?,?) ;`, intoBabbler.BabblerId, node.WordId, node.Root, node.RootFrequency) if err != nil { + log.Print(err) return err } - wordId, _ = res.LastInsertId() } - log.Printf("%s %d -> %d\n", word.Word, word.Id, wordId) + var updatedNode BabblerNode + err = p.db.QueryRowx(`select * from babblerNodes where babblerId = ? and wordId = ? LIMIT 1;`, intoBabbler.BabblerId, node.WordId).StructScan(&updatedNode) + if err != nil { + log.Print(err) + return err + } - mapping[word.Id] = wordId + mapping[node.NodeId] = &updatedNode } - type Arc struct { - ToWordId int64 - Frequency int64 - } - - for lookup, newArcStart := range mapping { - query = fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", lookup) - rows, err := p.db.Query(query) + for oldNodeId, newNode := range mapping { + rows, err := p.db.Queryx("select * from babblerArcs where fromNodeId = ?;", oldNodeId) if err != nil { return err } defer rows.Close() - arcs := []Arc{} + arcs := []*BabblerArc{} for rows.Next() { - var arc Arc - err = rows.Scan(&arc.ToWordId, &arc.Frequency) + var arc BabblerArc + err = rows.StructScan(&arc) if err != nil { + log.Print(err) return err } - arcs = append(arcs, arc) + arcs = append(arcs, &arc) } for _, arc := range arcs { - newArcEnd := int64(-1) //handle end arcs - if arc.ToWordId >= 0 { - newArcEnd = mapping[arc.ToWordId] - } - - res, err := p.db.Exec(`update babblerArcs set frequency = frequency + ? where fromWordId = ? and toWordId = ?`, arc.Frequency, newArcStart, newArcEnd) - rowsAffected := int64(0) - if err == nil { - rowsAffected, _ = res.RowsAffected() - } - if err != nil || rowsAffected == 0 { - _, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, ?);`, newArcStart, newArcEnd, arc.Frequency) - if err != nil { - return err - } + _, err := p.incrementWordArc(newNode, mapping[arc.ToNodeId]) + if err != nil { + return err } } } - return nil + + return err } diff --git a/plugins/babbler/babbler_test.go b/plugins/babbler/babbler_test.go index 7263ce7..a568b8b 100644 --- a/plugins/babbler/babbler_test.go +++ b/plugins/babbler/babbler_test.go @@ -31,8 +31,9 @@ func TestBabblerNoBabbler(t *testing.T) { c.config.Babbler.DefaultUsers = []string{"seabass"} assert.NotNil(t, c) c.Message(makeMessage("!seabass2 says")) - res := assert.Len(t, mb.Messages, 0) + res := assert.Len(t, mb.Messages, 1) assert.True(t, res) + assert.Contains(t, mb.Messages[0], "seabass2 babbler not found") } func TestBabblerNothingSaid(t *testing.T) { @@ -40,8 +41,13 @@ func TestBabblerNothingSaid(t *testing.T) { c := New(mb) c.config.Babbler.DefaultUsers = []string{"seabass"} assert.NotNil(t, c) - c.Message(makeMessage("!seabass says")) - assert.Len(t, mb.Messages, 0) + res := c.Message(makeMessage("initialize babbler for seabass")) + assert.True(t, res) + res = c.Message(makeMessage("!seabass says")) + assert.True(t, res) + assert.Len(t, mb.Messages, 2) + assert.Contains(t, mb.Messages[0], "okay.") + assert.Contains(t, mb.Messages[1], "seabass hasn't said anything yet.") } func TestBabbler(t *testing.T) { @@ -130,7 +136,8 @@ func TestBabblerBadSeed(t *testing.T) { seabass.Body = "This is a long message" c.Message(seabass) c.Message(makeMessage("!seabass says noooo this is bad")) - assert.Len(t, mb.Messages, 0) + assert.Len(t, mb.Messages, 1) + assert.Contains(t, mb.Messages[0], "seabass never said 'noooo this is bad'") } func TestBabblerBadSeed2(t *testing.T) { @@ -146,7 +153,8 @@ func TestBabblerBadSeed2(t *testing.T) { seabass.Body = "This is a long message" c.Message(seabass) c.Message(makeMessage("!seabass says This is a really")) - assert.Len(t, mb.Messages, 0) + assert.Len(t, mb.Messages, 1) + assert.Contains(t, mb.Messages[0], "seabass never said 'this is a really'") } func TestBabblerBatch(t *testing.T) { diff --git a/plugins/babbler/commands.go b/plugins/babbler/commands.go new file mode 100644 index 0000000..38e2502 --- /dev/null +++ b/plugins/babbler/commands.go @@ -0,0 +1,110 @@ +package babbler + +import ( + "fmt" + "strings" +) + +func (p *BabblerPlugin) initializeBabbler(tokens []string) (string, bool) { + who := tokens[3] + _, err := p.getOrCreateBabbler(who) + if err != nil { + return "babbler initialization failed.", true + } + return "okay.", true +} + +func (p *BabblerPlugin) addToBabbler(babblerName, whatWasSaid string) (string, bool) { + babblerId, err := p.getOrCreateBabbler(babblerName) + if err == nil { + p.addToMarkovChain(babblerId, whatWasSaid) + } + return "", false +} + +func (p *BabblerPlugin) getBabble(tokens []string) (string, bool) { + who := tokens[0] + _, err := p.getBabbler(who) + + if err != nil { + if err == NO_BABBLER { + return fmt.Sprintf("%s babbler not found.", who), true + } + } else { + + var saying string + if len(tokens) == 2 { + saying, err = p.babble(who) + } else { + saying, err = p.babbleSeed(who, tokens[2:]) + } + + if err != nil { + if err == SAID_NOTHING { + return fmt.Sprintf("%s hasn't said anything yet.", who), true + } else if err == NEVER_SAID { + return fmt.Sprintf("%s never said '%s'", who, strings.Join(tokens[2:], " ")), true + } + } else if saying != "" { + return saying, true + } + } + return "", false +} + +func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) { + who := tokens[3] + babblerId, err := p.getOrCreateBabbler(who) + if err != nil { + return "batch learn failed.", true + } + + body := strings.Join(tokens[4:], " ") + body = strings.ToLower(body) + + for _, a := range strings.Split(body, ".") { + for _, b := range strings.Split(a, "!") { + for _, c := range strings.Split(b, "?") { + for _, d := range strings.Split(c, "\n") { + trimmed := strings.TrimSpace(d) + if trimmed != "" { + p.addToMarkovChain(babblerId, trimmed) + } + } + } + } + } + return "phew that was tiring.", true +} + +func (p *BabblerPlugin) merge(tokens []string) (string, bool) { + if tokens[3] != "into" { + return "try using 'merge babbler [x] into [y]'", true + } + + who := tokens[2] + into := tokens[4] + + if who == into { + return "that's annoying. stop it.", true + } + + whoBabbler, err := p.getBabbler(who) + if err != nil { + if err == NO_BABBLER { + return fmt.Sprintf("%s babbler not found.", who), true + } + return "merge failed.", true + } + intoBabbler, err := p.getOrCreateBabbler(into) + if err != nil { + return "merge failed.", true + } + + err = p.mergeBabblers(intoBabbler, whoBabbler, into, who) + if err != nil { + return "merge failed.", true + } + + return "mooooiggged", true +}