2016-05-09 20:45:02 +00:00
// © 2013 the CatBase Authors under the WTFPL. See AUTHORS for the list of authors.
package babbler
import (
2017-05-09 23:41:40 +00:00
"errors"
2016-05-09 20:45:02 +00:00
"fmt"
2017-05-09 23:41:40 +00:00
"log"
2016-05-09 20:45:02 +00:00
"math/rand"
"strings"
"github.com/jmoiron/sqlx"
"github.com/velour/catbase/bot"
"github.com/velour/catbase/bot/msg"
2016-05-11 01:15:52 +00:00
"github.com/velour/catbase/config"
2016-05-09 20:45:02 +00:00
)
type BabblerPlugin struct {
2017-05-09 23:41:40 +00:00
Bot bot . Bot
db * sqlx . DB
config * config . Config
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
func New ( bot bot . Bot ) * BabblerPlugin {
log . SetFlags ( log . LstdFlags | log . Lshortfile )
if bot . DBVersion ( ) == 1 {
if _ , err := bot . DB ( ) . Exec ( ` create table if not exists babblers (
id integer primary key ,
babbler string
) ; ` ) ; err != nil {
log . Fatal ( err )
}
if _ , err := bot . DB ( ) . Exec ( ` create table if not exists babblerWords (
id integer primary key ,
babblerId integer ,
word string ,
root integer ,
rootFrequency integer
) ; ` ) ; err != nil {
log . Fatal ( err )
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
if _ , err := bot . DB ( ) . Exec ( ` create table if not exists babblerArcs (
id integer primary key ,
fromWordId integer ,
toWordId interger ,
frequency integer
) ; ` ) ; err != nil {
log . Fatal ( err )
}
}
2016-05-09 20:45:02 +00:00
plugin := & BabblerPlugin {
2017-05-09 23:41:40 +00:00
Bot : bot ,
db : bot . DB ( ) ,
config : bot . Config ( ) ,
2016-05-09 20:45:02 +00:00
}
2016-05-11 01:15:52 +00:00
return plugin
}
2016-05-09 20:45:02 +00:00
func ( p * BabblerPlugin ) Message ( message msg . Message ) bool {
lowercase := strings . ToLower ( message . Body )
tokens := strings . Fields ( lowercase )
2016-10-19 23:34:16 +00:00
numTokens := len ( tokens )
2016-05-09 20:45:02 +00:00
2016-10-19 23:34:16 +00:00
if numTokens >= 2 && tokens [ 1 ] == "says" {
2017-05-09 23:41:40 +00:00
who := tokens [ 0 ]
_ , err := p . getBabbler ( who )
if err != nil {
return false
}
2016-10-19 23:34:16 +00:00
var saying string
if len ( tokens ) == 2 {
2017-05-09 23:41:40 +00:00
saying , _ = p . babble ( who )
2016-10-19 23:34:16 +00:00
} else {
2017-05-09 23:41:40 +00:00
saying , _ = p . babbleSeed ( who , tokens [ 2 : ] )
2016-10-19 23:34:16 +00:00
}
2017-05-09 23:41:40 +00:00
2017-04-12 13:02:00 +00:00
if saying != "" {
2016-10-19 23:34:16 +00:00
p . Bot . SendMessage ( message . Channel , saying )
2017-05-09 23:41:40 +00:00
return true
2016-05-11 16:11:21 +00:00
}
2016-05-11 17:07:16 +00:00
} else if len ( tokens ) == 4 && strings . Index ( lowercase , "initialize babbler for " ) == 0 {
who := tokens [ 3 ]
2017-05-09 23:41:40 +00:00
_ , err := p . getOrCreateBabbler ( who )
if err != nil {
p . Bot . SendMessage ( message . Channel , "babbler initialization failed." )
2016-05-11 01:15:52 +00:00
return true
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
p . Bot . SendMessage ( message . Channel , "Okay." )
return true
2016-05-11 17:07:16 +00:00
} else if strings . Index ( lowercase , "batch learn for " ) == 0 {
who := tokens [ 3 ]
2017-05-09 23:41:40 +00:00
babblerId , err := p . getOrCreateBabbler ( who )
if err != nil {
p . Bot . SendMessage ( message . Channel , "batch learn failed." )
return true
2016-05-11 17:07:16 +00:00
}
body := strings . Join ( tokens [ 4 : ] , " " )
body = strings . ToLower ( body )
for _ , a := range strings . Split ( body , "." ) {
for _ , b := range strings . Split ( a , "!" ) {
for _ , c := range strings . Split ( b , "?" ) {
for _ , d := range strings . Split ( c , "\n" ) {
trimmed := strings . TrimSpace ( d )
if trimmed != "" {
2017-05-09 23:41:40 +00:00
p . addToMarkovChain ( babblerId , trimmed )
2016-05-11 17:07:16 +00:00
}
}
}
}
}
2017-05-09 23:41:40 +00:00
p . Bot . SendMessage ( message . Channel , "phew that was tiring." )
2016-05-11 17:07:16 +00:00
return true
2016-05-26 15:06:22 +00:00
} else if len ( tokens ) == 5 && strings . Index ( lowercase , "merge babbler" ) == 0 {
if tokens [ 3 ] != "into" {
p . Bot . SendMessage ( message . Channel , "try using 'merge babbler [x] into [y]'" )
return true
}
who := tokens [ 2 ]
into := tokens [ 4 ]
if who == into {
2017-05-09 23:41:40 +00:00
p . Bot . SendMessage ( message . Channel , "that's annoying. stop it." )
2016-05-26 15:06:22 +00:00
return true
}
2017-05-09 23:41:40 +00:00
whoBabbler , err := p . getBabbler ( who )
if err != nil {
p . Bot . SendMessage ( message . Channel , "merge failed." )
return true
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
intoBabbler , err := p . getOrCreateBabbler ( into )
if err != nil {
p . Bot . SendMessage ( message . Channel , "merge failed." )
return true
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
p . merge ( intoBabbler , whoBabbler , into , who )
2016-05-26 15:06:22 +00:00
p . Bot . SendMessage ( message . Channel , "mooooiggged" )
return true
2016-05-11 16:11:21 +00:00
} else {
2017-05-09 23:41:40 +00:00
babblerId , err := p . getOrCreateBabbler ( message . User . Name )
if err == nil {
p . addToMarkovChain ( babblerId , lowercase )
}
2016-05-09 20:45:02 +00:00
}
return false
}
func ( p * BabblerPlugin ) Help ( channel string , parts [ ] string ) {
2016-05-11 01:15:52 +00:00
p . Bot . SendMessage ( channel , "initialize babbler for seabass\n\nseabass says" )
2016-05-09 20:45:02 +00:00
}
// Event always returns false: the babbler does not handle events of any kind.
func (p *BabblerPlugin) Event(kind string, message msg.Message) bool {
	handled := false
	return handled
}
// BotMessage always returns false: messages sent by the bot itself are
// neither handled nor learned from.
func (p *BabblerPlugin) BotMessage(message msg.Message) bool {
	handled := false
	return handled
}
// RegisterWeb returns nil; presumably this means the babbler registers no
// web page (confirm against the bot's plugin interface).
func (p *BabblerPlugin) RegisterWeb() *string {
	var route *string
	return route
}
2017-05-09 23:41:40 +00:00
// makeBabbler inserts a new row named babbler into the babblers table. It
// returns the new row id, or -1 together with the insert error.
func (p *BabblerPlugin) makeBabbler(babbler string) (int64, error) {
	res, err := p.db.Exec(`insert into babblers (babbler) values (?);`, babbler)
	if err != nil {
		return -1, err
	}
	newId, _ := res.LastInsertId()
	return newId, nil
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
// getBabbler looks up the id of the babbler named babbler. Any lookup error
// is returned as-is; callers treat it as "no such babbler". The id starts out
// as -1 and is only replaced if the lookup writes a value.
func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) {
	var babblerId int64 = -1
	lookupErr := p.db.Get(&babblerId, `select id from babblers where babbler = ?`, babbler)
	return babblerId, lookupErr
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) getOrCreateBabbler ( babbler string ) ( int64 , error ) {
id , err := p . getBabbler ( babbler )
if err != nil {
id , err = p . makeBabbler ( babbler )
if err != nil {
return id , err
}
query := fmt . Sprintf ( ` select tidbit from factoid where fact like '%s quotes'; ` , babbler )
rows , err := p . db . Query ( query )
if err != nil {
//we'll just ignore this but the actual creation succeeded previously
return id , nil
}
for rows . Next ( ) {
var tidbit string
err := rows . Scan ( & tidbit )
if err != nil {
return id , err
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
p . addToMarkovChain ( id , tidbit )
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
return id , err
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
// getWordId looks up word in the vocabulary of babbler babblerId. Any lookup
// error is returned as-is; callers treat it as "unknown word". The id starts
// out as -1 and is only replaced if the lookup writes a value.
func (p *BabblerPlugin) getWordId(babblerId int64, word string) (int64, error) {
	const query = `select id from babblerWords where babblerId = ? and word = ?`
	wordId := int64(-1)
	if err := p.db.Get(&wordId, query, babblerId, word); err != nil {
		return wordId, err
	}
	return wordId, nil
}
// createNewWord adds word to babbler babblerId's vocabulary as a non-root word
// (root = 0, rootFrequency = 0). It returns the new row id, or -1 and the
// insert error.
func (p *BabblerPlugin) createNewWord(babblerId int64, word string) (int64, error) {
	res, err := p.db.Exec(`insert into babblerWords (babblerId, word, root, rootFrequency) values (?, ?, 0, 0);`, babblerId, word)
	if err == nil {
		newId, _ := res.LastInsertId()
		return newId, nil
	}
	return -1, err
}
// getOrCreateWord returns the id of word in babbler babblerId's vocabulary.
// If the lookup fails for any reason, the word is inserted as a new non-root
// word instead.
func (p *BabblerPlugin) getOrCreateWord(babblerId int64, word string) (int64, error) {
	if existing, err := p.getWordId(babblerId, word); err == nil {
		return existing, nil
	}
	return p.createNewWord(babblerId, word)
}
// incrementRootWordFrequency records one more phrase starting with word. It
// creates the word if needed, marks it as a root and bumps rootFrequency.
// It returns the word id, or -1 and the first error encountered.
func (p *BabblerPlugin) incrementRootWordFrequency(babblerId int64, word string) (int64, error) {
	wordId, err := p.getOrCreateWord(babblerId, word)
	if err == nil {
		_, err = p.db.Exec(`update babblerWords set rootFrequency = rootFrequency + 1, root = 1 where id = ?;`, wordId)
	}
	if err != nil {
		return -1, err
	}
	return wordId, nil
}
// getWordArcHelper looks up the id of the arc from fromWordId to toWordId.
// Any lookup error is returned as-is. The id starts out as -1 and is only
// replaced if the lookup writes a value.
func (p *BabblerPlugin) getWordArcHelper(fromWordId, toWordId int64) (int64, error) {
	var arcId int64 = -1
	lookupErr := p.db.Get(&arcId, `select id from babblerArcs where fromWordId = ? and toWordId = ?`, fromWordId, toWordId)
	return arcId, lookupErr
}
func ( p * BabblerPlugin ) incrementWordArc ( fromWordId , toWordId int64 ) ( int64 , error ) {
res , err := p . db . Exec ( ` update babblerArcs set frequency = frequency + 1 where fromWordId = ? and toWordId = ? ` , fromWordId , toWordId )
affectedRows := int64 ( 0 )
if err == nil {
affectedRows , _ = res . RowsAffected ( )
}
if err != nil || affectedRows == 0 {
res , err = p . db . Exec ( ` insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, 1); ` , fromWordId , toWordId )
if err != nil {
return - 1 , err
2016-05-09 20:45:02 +00:00
}
}
2017-05-09 23:41:40 +00:00
id , _ := res . LastInsertId ( )
return id , nil
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
// incrementFinalWordArcHelper counts one more phrase ending at wordId. A
// phrase end is stored as an arc whose toWordId is -1.
func (p *BabblerPlugin) incrementFinalWordArcHelper(wordId int64) (int64, error) {
	const endOfPhrase int64 = -1
	return p.incrementWordArc(wordId, endOfPhrase)
}
func ( p * BabblerPlugin ) incrementWordArcHelper ( babblerId , fromWordId int64 , toWord string ) ( int64 , error ) {
toWordId , err := p . getOrCreateWord ( babblerId , toWord )
if err != nil {
return - 1 , err
2016-10-19 23:34:16 +00:00
}
2017-05-09 23:41:40 +00:00
_ , err = p . incrementWordArc ( fromWordId , toWordId )
if err != nil {
return - 1 , err
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
return toWordId , nil
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) addToMarkovChain ( babblerId int64 , phrase string ) {
words := strings . Fields ( strings . ToLower ( phrase ) )
id , err := p . incrementRootWordFrequency ( babblerId , words [ 0 ] )
2016-05-09 20:45:02 +00:00
if err != nil {
2017-05-09 23:41:40 +00:00
return
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
for i := 1 ; i < len ( words ) ; i ++ {
id , err = p . incrementWordArcHelper ( babblerId , id , words [ i ] )
if err != nil {
return
}
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
_ , err = p . incrementFinalWordArcHelper ( id )
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) getWeightedRootWord ( babblerId int64 ) ( int64 , string , error ) {
query := fmt . Sprintf ( "select id, word, rootFrequency from babblerWords where babblerId = %d and root = 1" , babblerId )
rows , err := p . db . Query ( query )
if err != nil {
return - 1 , "" , err
}
defer rows . Close ( )
idToWord := map [ int64 ] string { }
idToFreq := map [ int64 ] int64 { }
total := int64 ( 0 )
for rows . Next ( ) {
var id int64
var word string
var rootFrequency int64
err = rows . Scan ( & id , & word , & rootFrequency )
2016-05-09 20:45:02 +00:00
if err != nil {
2017-05-09 23:41:40 +00:00
return - 1 , "" , err
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
total += rootFrequency
idToFreq [ id ] = rootFrequency
idToWord [ id ] = word
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
which := rand . Int63n ( total )
total = 0
for id , freq := range idToFreq {
if total + freq >= which {
return id , idToWord [ id ] , nil
}
total += freq
}
log . Fatalf ( "shouldn't happen" )
return - 1 , "" , errors . New ( "failed to find weighted root word" )
2016-10-19 23:34:16 +00:00
}
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) getWeightedNextWord ( fromWordId int64 ) ( int64 , string , error ) {
query := fmt . Sprintf ( "select toWordId, frequency from babblerArcs where fromWordId = %d;" , fromWordId )
rows , err := p . db . Query ( query )
if err != nil {
return - 1 , "" , err
}
defer rows . Close ( )
idToFreq := map [ int64 ] int64 { }
total := int64 ( 0 )
for rows . Next ( ) {
var toWordId int64
var frequency int64
err = rows . Scan ( & toWordId , & frequency )
if err != nil {
return - 1 , "" , err
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
total += frequency
idToFreq [ toWordId ] = frequency
}
2017-04-07 19:35:18 +00:00
2017-05-09 23:41:40 +00:00
which := rand . Int63n ( total )
total = 0
for id , freq := range idToFreq {
if total + freq >= which {
if id < 0 {
return - 1 , "" , nil
2017-04-07 19:35:18 +00:00
}
2017-05-09 23:41:40 +00:00
var word string
err := p . db . Get ( & word , ` select word from babblerWords where id = ? ` , id )
if err != nil {
return - 1 , "" , err
2017-04-07 19:35:18 +00:00
}
2017-05-09 23:41:40 +00:00
return id , word , nil
2017-04-07 19:35:18 +00:00
}
2017-05-09 23:41:40 +00:00
total += freq
}
log . Fatalf ( "shouldn't happen" )
return - 1 , "" , errors . New ( "failed to find weighted next word" )
}
2017-04-07 19:35:18 +00:00
2017-05-09 23:41:40 +00:00
// babble generates a phrase for who with no seed words, so generation starts
// from a randomly chosen root word.
func (p *BabblerPlugin) babble(who string) (string, error) {
	var noSeed []string
	return p.babbleSeed(who, noSeed)
}
func ( p * BabblerPlugin ) babbleSeed ( babbler string , seed [ ] string ) ( string , error ) {
babblerId , err := p . getBabbler ( babbler )
if err != nil {
return "" , nil
}
words := seed
var curWordId int64
if len ( seed ) == 0 {
id , word , err := p . getWeightedRootWord ( babblerId )
if err != nil {
return "" , err
}
curWordId = id
words = append ( words , word )
} else {
id , err := p . getWordId ( babblerId , seed [ 0 ] )
if err != nil {
return "" , err
}
curWordId = id
for i := 1 ; i < len ( seed ) ; i ++ {
nextWordId , err := p . getWordId ( babblerId , seed [ i ] )
if err != nil {
return "" , err
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
_ , err = p . getWordArcHelper ( curWordId , nextWordId )
if err != nil {
return "" , err
}
curWordId = nextWordId
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
}
2016-05-09 20:45:02 +00:00
2017-05-09 23:41:40 +00:00
for {
id , word , err := p . getWeightedNextWord ( curWordId )
if err != nil {
return "" , err
}
if id < 0 {
break
}
words = append ( words , word )
curWordId = id
2016-05-09 20:45:02 +00:00
}
2017-05-09 23:41:40 +00:00
return strings . TrimSpace ( strings . Join ( words , " " ) ) , nil
2016-05-11 01:15:52 +00:00
}
2016-05-26 15:06:22 +00:00
2017-05-09 23:41:40 +00:00
func ( p * BabblerPlugin ) merge ( intoId , otherId int64 , intoName , otherName string ) error {
intoString := "<" + intoName + ">"
otherString := "<" + otherName + ">"
2016-05-26 15:06:22 +00:00
2017-05-09 23:41:40 +00:00
mapping := map [ int64 ] int64 { }
query := fmt . Sprintf ( "select id, word, root, rootFrequency from babblerWords where babblerId = %d;" , otherId )
rows , err := p . db . Query ( query )
if err != nil {
return err
}
defer rows . Close ( )
type Word struct {
Id int64
Word string
Root int64
RootFrequency int64
}
words := [ ] Word { }
for rows . Next ( ) {
word := Word { }
err = rows . Scan ( & word . Id , & word . Word , & word . Root , & word . RootFrequency )
if err != nil {
return err
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
words = append ( words , word )
}
2016-05-26 15:06:22 +00:00
2017-05-09 23:41:40 +00:00
for _ , word := range words {
if word . Word == otherString {
word . Word = intoString
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
doInsert := false
wordId := int64 ( - 1 )
if word . Root > 0 {
res , err := p . db . Exec ( ` update babblerWords set rootFrequency = rootFrequency + ?, root = 1 where babblerId = ? and word = ? output id ; ` , word . RootFrequency , intoId , word . Word )
rowsAffected := int64 ( 0 )
if err == nil {
rowsAffected , _ = res . RowsAffected ( )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
if err != nil || rowsAffected == 0 {
doInsert = true
} else {
wordId , _ = res . LastInsertId ( )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
} else {
res , err := p . db . Exec ( ` update babblerWords set rootFrequency = rootFrequency + ? where babblerId = ? and word = ? output id; ` , word . RootFrequency , intoId , word . Word )
if err != nil {
doInsert = true
} else {
wordId , _ = res . LastInsertId ( )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
}
2016-05-26 15:06:22 +00:00
2017-05-09 23:41:40 +00:00
if doInsert {
res , err := p . db . Exec ( ` insert into babblerWords (babblerId, word, root, rootFrequency) values (?,?,?,?) ; ` , intoId , word . Word , word . Root , word . RootFrequency )
if err != nil {
return err
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
wordId , _ = res . LastInsertId ( )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
log . Printf ( "%s %d -> %d\n" , word . Word , word . Id , wordId )
mapping [ word . Id ] = wordId
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
type Arc struct {
ToWordId int64
Frequency int64
}
for lookup , newArcStart := range mapping {
query = fmt . Sprintf ( "select toWordId, frequency from babblerArcs where fromWordId = %d;" , lookup )
rows , err := p . db . Query ( query )
if err != nil {
return err
}
defer rows . Close ( )
arcs := [ ] Arc { }
2016-05-26 15:06:22 +00:00
2017-05-09 23:41:40 +00:00
for rows . Next ( ) {
var arc Arc
err = rows . Scan ( & arc . ToWordId , & arc . Frequency )
if err != nil {
return err
}
arcs = append ( arcs , arc )
2016-05-26 15:06:22 +00:00
}
2017-05-09 23:41:40 +00:00
for _ , arc := range arcs {
newArcEnd := int64 ( - 1 ) //handle end arcs
if arc . ToWordId >= 0 {
newArcEnd = mapping [ arc . ToWordId ]
}
res , err := p . db . Exec ( ` update babblerArcs set frequency = frequency + ? where fromWordId = ? and toWordId = ? ` , arc . Frequency , newArcStart , newArcEnd )
rowsAffected := int64 ( 0 )
if err == nil {
rowsAffected , _ = res . RowsAffected ( )
}
if err != nil || rowsAffected == 0 {
_ , err = p . db . Exec ( ` insert into babblerArcs (fromWordId, toWordId, frequency) values (?, ?, ?); ` , newArcStart , newArcEnd , arc . Frequency )
if err != nil {
return err
}
2016-05-26 15:06:22 +00:00
}
}
}
2017-05-09 23:41:40 +00:00
return nil
2016-05-26 15:06:22 +00:00
}