it was a disaster before, now it is less.

This commit is contained in:
skkiesel 2017-05-17 10:06:10 -04:00
parent e97ce576b8
commit 9e56b54680
1 changed files with 314 additions and 279 deletions

View File

@ -17,18 +17,42 @@ import (
) )
var ( var (
NO_BABBLER = errors.New("babbler not found") NO_BABBLER = errors.New("babbler not found")
SAID_NOTHING = errors.New("hasn't said anything yet") SAID_NOTHING = errors.New("hasn't said anything yet")
NEVER_SAID = errors.New("never said that") NEVER_SAID = errors.New("never said that")
) )
type BabblerPlugin struct { type BabblerPlugin struct {
Bot bot.Bot Bot bot.Bot
db *sqlx.DB db *sqlx.DB
config *config.Config config *config.Config
} }
type Babbler struct {
BabblerId int64 `db:"id"`
Name string `db:"babbler"`
}
type BabblerWord struct {
WordId int64 `db:"id"`
Word string `db:"word"`
}
type BabblerNode struct {
NodeId int64 `db:"id"`
BabblerId int64 `db:"babblerId"`
WordId int64 `db:"wordId"`
Root int64 `db:"root"`
RootFrequency int64 `db:"rootFrequency"`
}
type BabblerArc struct {
ArcId int64 `db:"id"`
FromNodeId int64 `db:"fromNodeId"`
ToNodeId int64 `db:"toNodeId"`
Frequency int64 `db:"frequency"`
}
func New(bot bot.Bot) *BabblerPlugin { func New(bot bot.Bot) *BabblerPlugin {
log.SetFlags(log.LstdFlags | log.Lshortfile) log.SetFlags(log.LstdFlags | log.Lshortfile)
@ -59,8 +83,8 @@ func New(bot bot.Bot) *BabblerPlugin {
if _, err := bot.DB().Exec(`create table if not exists babblerArcs ( if _, err := bot.DB().Exec(`create table if not exists babblerArcs (
id integer primary key, id integer primary key,
fromWordId integer, fromNodeId integer,
toWordId interger, toNodeId interger,
frequency integer frequency integer
);`); err != nil { );`); err != nil {
log.Fatal(err) log.Fatal(err)
@ -73,6 +97,8 @@ func New(bot bot.Bot) *BabblerPlugin {
config: bot.Config(), config: bot.Config(),
} }
plugin.createNewWord("")
return plugin return plugin
} }
@ -119,36 +145,50 @@ func (p *BabblerPlugin) RegisterWeb() *string {
return nil return nil
} }
func (p *BabblerPlugin) makeBabbler(babbler string) (int64, error) { func (p *BabblerPlugin) makeBabbler(name string) (*Babbler, error) {
res, err := p.db.Exec(`insert into babblers (babbler) values (?);`, babbler) res, err := p.db.Exec(`insert into babblers (babbler) values (?);`, name)
if err == nil { if err == nil {
id, _ := res.LastInsertId() id, err := res.LastInsertId()
return id, nil
}
return -1, err
}
func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) {
id := int64(-1)
err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler)
if err == sql.ErrNoRows {
return -1, NO_BABBLER
}
return id, err
}
func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) {
id, err := p.getBabbler(babbler)
if err != nil {
id, err = p.makeBabbler(babbler)
if err != nil { if err != nil {
return id, err log.Print(err)
return nil, err
} }
query := fmt.Sprintf(`select tidbit from factoid where fact like '%s quotes';`, babbler) return &Babbler{
rows, err := p.db.Query(query) BabblerId: id,
Name: name,
}, nil
}
return nil, err
}
func (p *BabblerPlugin) getBabbler(name string) (*Babbler, error) {
var bblr Babbler
err := p.db.QueryRowx(`select * from babblers where babbler = ? LIMIT 1;`, name).StructScan(&bblr)
if err != nil {
if err == sql.ErrNoRows {
log.Printf("failed to find babbler")
return nil, NO_BABBLER
}
log.Printf("encountered problem in babbler lookup")
log.Print(err)
return nil, err
}
return &bblr, nil
}
func (p *BabblerPlugin) getOrCreateBabbler(name string) (*Babbler, error) {
babbler, err := p.getBabbler(name)
if err == NO_BABBLER {
babbler, err = p.makeBabbler(name)
if err != nil { if err != nil {
//we'll just ignore this but the actual creation succeeded previously log.Print(err)
return id, nil return nil, err
}
rows, err := p.db.Queryx(fmt.Sprintf("select tidbit from factoid where fact like '%s quotes';", babbler.Name))
if err != nil {
log.Print(err)
return babbler, nil
} }
defer rows.Close() defer rows.Close()
@ -156,401 +196,411 @@ func (p *BabblerPlugin) getOrCreateBabbler(babbler string) (int64, error) {
for rows.Next() { for rows.Next() {
var tidbit string var tidbit string
err := rows.Scan(&tidbit) err := rows.Scan(&tidbit)
log.Print(tidbit)
if err != nil { if err != nil {
return id, err log.Print(err)
return babbler, err
} }
tidbits = append(tidbits, tidbit) tidbits = append(tidbits, tidbit)
} }
for _, tidbit := range tidbits { for _, tidbit := range tidbits {
p.addToMarkovChain(id, tidbit) if err = p.addToMarkovChain(babbler, tidbit); err != nil {
log.Print(err)
}
} }
} }
return id, err return babbler, err
} }
func (p *BabblerPlugin) getWordId(word string) (int64, error) { func (p *BabblerPlugin) getWord(word string) (*BabblerWord, error) {
id := int64(-1) var w BabblerWord
if err := p.db.Get(&id, `select id from babblerWords where word = ?`, word); err == sql.ErrNoRows { err := p.db.QueryRowx(`select * from babblerWords where word = ? LIMIT 1;`, word).StructScan(&w)
return -1, NEVER_SAID if err != nil {
} else { if err == sql.ErrNoRows {
if err != nil { return nil, NEVER_SAID
log.Print(err)
} }
return id, err return nil, err
} }
return &w, nil
} }
func (p *BabblerPlugin) createNewWord(word string) (int64, error) { func (p *BabblerPlugin) createNewWord(word string) (*BabblerWord, error) {
if res, err := p.db.Exec(`insert into babblerWords (word) values (?);`, word); err == nil { res, err := p.db.Exec(`insert into babblerWords (word) values (?);`, word)
id, _ := res.LastInsertId() if err != nil {
return id, nil log.Print(err)
} else { return nil, err
if err != nil {
log.Print(err)
}
return -1, err
} }
id, err := res.LastInsertId()
if err != nil {
log.Print(err)
return nil, err
}
return &BabblerWord{
WordId: id,
Word: word,
}, nil
} }
func (p *BabblerPlugin) getOrCreateWord(word string) (int64, error) { func (p *BabblerPlugin) getOrCreateWord(word string) (*BabblerWord, error) {
if id, err := p.getWordId(word); err == NEVER_SAID { if w, err := p.getWord(word); err == NEVER_SAID {
return p.createNewWord(word) return p.createNewWord(word)
} else {
return id, err
}
}
func (p *BabblerPlugin) getWordNodeId(babblerId int64, word string) (int64, error) {
var err error
wordId := int64(-1)
if wordId, err = p.getWordId(word); err != nil {
return -1, err
}
id := int64(-1)
if err = p.db.Get(&id, `select id from babblerNodes where babblerId = ? and wordId = ?`, babblerId, wordId); err == sql.ErrNoRows {
return -1, NEVER_SAID
} else { } else {
if err != nil { if err != nil {
log.Print(err) log.Print(err)
} }
return id, err return w, err
} }
} }
func (p *BabblerPlugin) createWordNodeId(babblerId int64, word string) (int64, error) { func (p *BabblerPlugin) getBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) {
var err error w, err := p.getWord(word)
wordId := int64(-1) if err != nil {
if wordId, err = p.getOrCreateWord(word); err != nil { return nil, err
log.Print(err)
return -1, err
} }
id := int64(-1)
res, err := p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?, ?, 0, 0)`, babblerId, wordId) var node BabblerNode
if err == sql.ErrNoRows { err = p.db.QueryRowx(`select * from babblerNodes where babblerId = ? and wordId = ? LIMIT 1;`, babbler.BabblerId, w.WordId).StructScan(&node)
return -1, NEVER_SAID if err != nil {
} else if err == nil { if err == sql.ErrNoRows {
id, _ = res.LastInsertId() return nil, NEVER_SAID
}
return nil, err
} }
return &node, nil
}
func (p *BabblerPlugin) createBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) {
w, err := p.getOrCreateWord(word)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
} return nil, err
return id, err
}
func (p *BabblerPlugin) getOrCreateWordNode(babblerId int64, word string) (int64, error) {
if id, err := p.getWordNodeId(babblerId, word); err == NEVER_SAID {
return p.createWordNodeId(babblerId, word)
} else {
if err != nil {
log.Print(err)
}
return id, err
}
}
func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) {
var err error
id := int64(-1)
if id, err = p.getOrCreateWordNode(babblerId, word); err != nil {
return -1, err
} }
if _, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + 1, root = 1 where babblerId = ? and wordId = ?;`, babblerId, id); err != nil { res, err := p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?, ?, 0, 0)`, babbler.BabblerId, w.WordId)
if err != nil {
log.Print(err)
}
return -1, err
}
return id, nil
}
func (p *BabblerPlugin) getWordArcHelper(fromWordId, toWordId int64) (int64, error) {
id := int64(-1)
err := p.db.Get(&id, `select id from babblerArcs where fromWordId = ? and toWordId = ?`, fromWordId, toWordId)
if err == sql.ErrNoRows {
return -1, NEVER_SAID
}
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return nil, err
} }
return id, err
id, err := res.LastInsertId()
if err != nil {
log.Print(err)
return nil, err
}
return &BabblerNode{
NodeId: id,
WordId: w.WordId,
Root: 0,
RootFrequency: 0,
}, nil
} }
func (p *BabblerPlugin) incrementWordArc(fromWordId, toWordId int64) (int64, error) { func (p *BabblerPlugin) getOrCreateBabblerNode(babbler *Babbler, word string) (*BabblerNode, error) {
res, err := p.db.Exec(`update babblerArcs set frequency = frequency + 1 where fromWordId = ? and toWordId = ?`, fromWordId, toWordId) node, err := p.getBabblerNode(babbler, word)
if err != nil {
return p.createBabblerNode(babbler, word)
}
return node, nil
}
func (p *BabblerPlugin) incrementRootWordFrequency(babbler *Babbler, word string) (*BabblerNode, error) {
node, err := p.getOrCreateBabblerNode(babbler, word)
if err != nil {
log.Print(err)
return nil, err
}
_, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, node.NodeId)
if err != nil {
log.Print(err)
return nil, err
}
node.RootFrequency += 1
return node, nil
}
func (p *BabblerPlugin) getBabblerArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) {
var arc BabblerArc
err := p.db.QueryRowx(`select * from babblerArcs where fromNodeId = ? and toNodeId = ?;`, fromNode.NodeId, toNode.NodeId).StructScan(&arc)
if err != nil {
if err == sql.ErrNoRows {
return nil, NEVER_SAID
}
return nil, err
}
return &arc, nil
}
func (p *BabblerPlugin) incrementWordArc(fromNode, toNode *BabblerNode) (*BabblerArc, error) {
res, err := p.db.Exec(`update babblerArcs set frequency = frequency + 1 where fromNodeId = ? and toNodeId = ?;`, fromNode.NodeId, toNode.NodeId)
if err != nil {
log.Print(err)
return nil, err
}
affectedRows := int64(0) affectedRows := int64(0)
if err == nil { if err == nil {
affectedRows, _ = res.RowsAffected() affectedRows, _ = res.RowsAffected()
} }
if err != nil || affectedRows == 0 { if affectedRows == 0 {
res, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1);`, fromWordId, toWordId) res, err = p.db.Exec(`insert into babblerArcs (fromNodeId, toNodeId, frequency) values (?, ?, 1);`, fromNode.NodeId, toNode.NodeId)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return -1, err return nil, err
} }
} }
id, _ := res.LastInsertId() return p.getBabblerArc(fromNode, toNode)
return id, nil
} }
func (p *BabblerPlugin) incrementFinalWordArcHelper(wordId int64) (int64, error) { func (p *BabblerPlugin) incrementFinalWordArcHelper(babbler *Babbler, node *BabblerNode) (*BabblerArc, error) {
return p.incrementWordArc(wordId, -1) nextNode, err := p.getOrCreateBabblerNode(babbler, " ")
}
func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWord string) (int64, error) {
toWordId, err := p.getOrCreateWordNode(babblerId, toWord)
if err != nil { if err != nil {
return -1, err return nil, err
} }
_, err = p.incrementWordArc(fromWordId, toWordId) return p.incrementWordArc(node, nextNode)
if err != nil {
return -1, err
}
return toWordId, nil
} }
func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { func (p *BabblerPlugin) addToMarkovChain(babbler *Babbler, phrase string) error {
words := strings.Fields(strings.ToLower(phrase)) words := strings.Fields(strings.ToLower(phrase))
if len(words) <= 0 { if len(words) <= 0 {
return return nil
} }
id, err := p.incrementRootWordFrequency(babblerId, words[0]) curNode, err := p.incrementRootWordFrequency(babbler, words[0])
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return return err
} }
for i := 1; i < len(words); i++ { for i := 1; i < len(words); i++ {
id, err = p.incrementWordArcHelper(babblerId, id, words[i]) nextNode, err := p.getOrCreateBabblerNode(babbler, words[i])
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return return err
} }
_, err = p.incrementWordArc(curNode, nextNode)
if err != nil {
log.Print(err)
return err
}
curNode = nextNode
} }
_, err = p.incrementFinalWordArcHelper(id) _, err = p.incrementFinalWordArcHelper(babbler, curNode)
return err
} }
func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, error) { func (p *BabblerPlugin) getWeightedRootNode(babbler *Babbler) (*BabblerNode, *BabblerWord, error) {
query := fmt.Sprintf("select id, wordId, rootFrequency from babblerNodes where babblerId = %d and root = 1", babblerId) rows, err := p.db.Queryx(`select * from babblerNodes where babblerId = ? and root = 1;`, babbler.BabblerId)
rows, err := p.db.Query(query)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return -1, "", err return nil, nil, err
} }
defer rows.Close() defer rows.Close()
idToWordId := map[int64]int64{} rootNodes := []*BabblerNode{}
idToFreq := map[int64]int64{}
total := int64(0) total := int64(0)
for rows.Next() { for rows.Next() {
var id int64 var node BabblerNode
var wordId int64 err = rows.StructScan(&node)
var rootFrequency int64
err = rows.Scan(&id, &wordId, &rootFrequency)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return -1, "", err return nil, nil, err
} }
rootNodes = append(rootNodes, &node)
total += rootFrequency total += node.RootFrequency
idToFreq[id] = rootFrequency
idToWordId[id] = wordId
} }
if total == 0 { if len(rootNodes) == 0 {
return -1, "", SAID_NOTHING return nil, nil, SAID_NOTHING
} }
which := rand.Int63n(total) which := rand.Int63n(total)
total = 0 total = 0
for id, freq := range idToFreq { for _, node := range rootNodes {
if total+freq >= which { total += node.RootFrequency
var word string if total >= which {
err := p.db.Get(&word, `select word from babblerWords where id = ?`, idToWordId[id]) var w BabblerWord
err := p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, node.WordId).StructScan(&w)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return -1, "", err return nil, nil, err
} }
return id, word, nil return node, &w, nil
} }
total += freq
} }
log.Fatalf("shouldn't happen") log.Fatalf("shouldn't happen")
return -1, "", errors.New("failed to find weighted root word") return nil, nil, errors.New("failed to find weighted root word")
} }
func (p *BabblerPlugin) getWeightedNextWord(fromWordId int64) (int64, string, error) { func (p *BabblerPlugin) getWeightedNextWord(fromNode *BabblerNode) (*BabblerNode, *BabblerWord, error) {
query := fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", fromWordId) rows, err := p.db.Queryx(`select * from babblerArcs where fromNodeId = ?;`, fromNode.NodeId)
rows, err := p.db.Query(query)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return -1, "", err return nil, nil, err
} }
defer rows.Close() defer rows.Close()
idToFreq := map[int64]int64{} arcs := []*BabblerArc{}
total := int64(0) total := int64(0)
for rows.Next() { for rows.Next() {
var toWordId int64 var arc BabblerArc
var frequency int64 err = rows.StructScan(&arc)
err = rows.Scan(&toWordId, &frequency)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return -1, "", err return nil, nil, err
} }
total += frequency arcs = append(arcs, &arc)
idToFreq[toWordId] = frequency total += arc.Frequency
} }
if total == 0 { if len(arcs) == 0 {
return -1, "", errors.New("missing arcs") return nil, nil, errors.New("missing arcs")
} }
which := rand.Int63n(total) which := rand.Int63n(total)
total = 0 total = 0
for id, freq := range idToFreq { for _, arc := range arcs {
if total+freq >= which {
if id < 0 { total += arc.Frequency
return -1, "", nil
if total >= which {
var node BabblerNode
err := p.db.QueryRowx(`select * from babblerNodes where id = ? LIMIT 1;`, arc.ToNodeId).StructScan(&node)
if err != nil {
log.Print(err)
return nil, nil, err
} }
var wordId int64 var w BabblerWord
if err := p.db.Get(&wordId, `select wordId from babblerNodes where id = ?`, id); err != nil { err = p.db.QueryRowx(`select * from babblerWords where id = ? LIMIT 1;`, node.WordId).StructScan(&w)
if err != nil {
log.Print(err) log.Print(err)
return -1, "", err return nil, nil, err
} }
var word string return &node, &w, nil
if err := p.db.Get(&word, `select word from babblerWords where id = ?`, wordId); err != nil {
log.Print(err)
return -1, "", err
}
return id, word, nil
} }
total +=freq
} }
log.Fatalf("shouldn't happen") log.Fatalf("shouldn't happen")
return -1, "", errors.New("failed to find weighted next word") return nil, nil, errors.New("failed to find weighted next word")
} }
func (p *BabblerPlugin) babble(who string) (string, error) { func (p *BabblerPlugin) babble(who string) (string, error) {
return p.babbleSeed(who, []string{}) return p.babbleSeed(who, []string{})
} }
func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error) { func (p *BabblerPlugin) babbleSeed(babblerName string, seed []string) (string, error) {
babblerId, err := p.getBabbler(babbler) babbler, err := p.getBabbler(babblerName)
if err != nil { if err != nil {
log.Print(err)
return "", nil return "", nil
} }
words := seed words := seed
var curWordId int64 var curNode *BabblerNode
var curWord *BabblerWord
if len(seed) == 0 { if len(seed) == 0 {
id, word, err := p.getWeightedRootWord(babblerId) curNode, curWord, err = p.getWeightedRootNode(babbler)
if err != nil { if err != nil {
log.Print(err)
return "", err return "", err
} }
curWordId = id words = append(words, curWord.Word)
words = append(words, word)
} else { } else {
id, err := p.getWordNodeId(babblerId, seed[0]) curNode, err = p.getBabblerNode(babbler, seed[0])
if err != nil { if err != nil {
log.Print(err)
return "", err return "", err
} }
curWordId = id
for i := 1; i < len(seed); i++ { for i := 1; i < len(seed); i++ {
nextWordId, err := p.getWordNodeId(babblerId, seed[i]) nextNode, err := p.getBabblerNode(babbler, seed[i])
if err != nil { if err != nil {
log.Print(err)
return "", err return "", err
} }
_, err = p.getWordArcHelper(curWordId, nextWordId) _, err = p.getBabblerArc(curNode, nextNode)
if err != nil { if err != nil {
log.Print(err)
return "", err return "", err
} }
curWordId = nextWordId curNode = nextNode
} }
} }
for { for {
id, word, err := p.getWeightedNextWord(curWordId) curNode, curWord, err = p.getWeightedNextWord(curNode)
if err != nil { if err != nil {
log.Print(err)
return "", err return "", err
} }
if id < 0 { if curWord.Word == " " {
break break
} }
words = append(words, word) words = append(words, curWord.Word)
curWordId = id
} }
return strings.TrimSpace(strings.Join(words, " ")), nil return strings.TrimSpace(strings.Join(words, " ")), nil
} }
func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName string) error { func (p *BabblerPlugin) mergeBabblers(intoBabbler, otherBabbler *Babbler, intoName, otherName string) error {
intoStringId, err := p.getOrCreateWordNode(intoId, "<" + intoName + ">") intoNode, err := p.getOrCreateBabblerNode(intoBabbler, "<"+intoName+">")
if err != nil { if err != nil {
log.Print(err)
return err return err
} }
otherStringId, err := p.getOrCreateWordNode(otherId, "<" + otherName + ">") otherNode, err := p.getOrCreateBabblerNode(otherBabbler, "<"+otherName+">")
if err != nil { if err != nil {
log.Print(err)
return err return err
} }
log.Print(intoStringId) mapping := map[int64]*BabblerNode{}
log.Print(otherStringId)
mapping := map[int64]int64{} rows, err := p.db.Queryx("select * from babblerNodes where babblerId = ?;", otherBabbler.BabblerId)
query := fmt.Sprintf("select wordId, root, rootFrequency from babblerNodes where babblerId = %d;", otherId)
rows, err := p.db.Query(query)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return err return err
} }
defer rows.Close() defer rows.Close()
type Word struct { nodes := []*BabblerNode{}
Id int64
Word string
Root int64
RootFrequency int64
}
words := []Word{}
for rows.Next() { for rows.Next() {
word := Word{} var node BabblerNode
err = rows.Scan(&word.Id, &word.Root, &word.RootFrequency) err = rows.StructScan(&node)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return err return err
} }
words = append(words, word) nodes = append(nodes, &node)
} }
for _, word := range words { for _, node := range nodes {
if word.Id == otherStringId {
log.Print("HERE")
word.Id = intoStringId
}
var res sql.Result var res sql.Result
if word.Root > 0 { if node.NodeId == otherNode.NodeId {
res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and wordId = ? output id ;`, word.RootFrequency, intoId, word.Id) node.WordId = intoNode.WordId
}
if node.Root > 0 {
res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and wordId = ?;`, node.RootFrequency, intoBabbler.BabblerId, node.WordId)
if err != nil {
log.Print(err)
}
} else { } else {
res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ? where babblerId = ? and wordId = ? output id;`, word.RootFrequency, intoId, word.Id) res, err = p.db.Exec(`update babblerNodes set rootFrequency = rootFrequency + ? where babblerId = ? and wordId = ?;`, node.RootFrequency, intoBabbler.BabblerId, node.WordId)
if err != nil {
log.Print(err)
}
} }
rowsAffected := int64(-1) rowsAffected := int64(-1)
@ -558,65 +608,50 @@ func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName
rowsAffected, _ = res.RowsAffected() rowsAffected, _ = res.RowsAffected()
} }
wordId := int64(-1)
if err != nil || rowsAffected == 0 { if err != nil || rowsAffected == 0 {
res, err = p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?,?,?,?) ;`, intoId, word.Id, word.Root, word.RootFrequency) res, err = p.db.Exec(`insert into babblerNodes (babblerId, wordId, root, rootFrequency) values (?,?,?,?) ;`, intoBabbler.BabblerId, node.WordId, node.Root, node.RootFrequency)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return err return err
} }
} }
wordId, _ = res.LastInsertId() var updatedNode BabblerNode
err = p.db.QueryRowx(`select * from babblerNodes where babblerId = ? and wordId = ? LIMIT 1;`, intoBabbler.BabblerId, node.WordId).StructScan(&updatedNode)
if word.Id == intoStringId { if err != nil {
word.Id = otherStringId log.Print(err)
return err
} }
mapping[word.Id] = wordId
mapping[node.NodeId] = &updatedNode
} }
type Arc struct { for oldNodeId, newNode := range mapping {
ToWordId int64 rows, err := p.db.Queryx("select * from babblerArcs where fromNodeId = ?;", oldNodeId)
Frequency int64
}
for lookup, newArcStart := range mapping {
query = fmt.Sprintf("select toWordId, frequency from babblerArcs where fromWordId = %d;", lookup)
rows, err := p.db.Query(query)
if err != nil { if err != nil {
return err return err
} }
defer rows.Close() defer rows.Close()
arcs := []Arc{} arcs := []*BabblerArc{}
for rows.Next() { for rows.Next() {
var arc Arc var arc BabblerArc
err = rows.Scan(&arc.ToWordId, &arc.Frequency) err = rows.StructScan(&arc)
if err != nil { if err != nil {
log.Print(err)
return err return err
} }
arcs = append(arcs, arc) arcs = append(arcs, &arc)
} }
for _, arc := range arcs { for _, arc := range arcs {
newArcEnd := int64(-1) //handle end arcs _, err := p.incrementWordArc(newNode, mapping[arc.ToNodeId])
if arc.ToWordId >= 0 { if err != nil {
newArcEnd = mapping[arc.ToWordId] return err
}
res, err := p.db.Exec(`update babblerArcs set frequency = frequency + ? where fromWordId = ? and toWordId = ?`, arc.Frequency, newArcStart, newArcEnd)
rowsAffected := int64(0)
if err == nil {
rowsAffected, _ = res.RowsAffected()
}
if err != nil || rowsAffected == 0 {
_, err = p.db.Exec(`insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, ?);`, newArcStart, newArcEnd, arc.Frequency)
if err != nil {
return err
}
} }
} }
} }
return nil
return err
} }