diff --git a/plugins/babbler/babbler.go b/plugins/babbler/babbler.go index 4b144b9..49c2aed 100644 --- a/plugins/babbler/babbler.go +++ b/plugins/babbler/babbler.go @@ -16,9 +16,9 @@ import ( "github.com/velour/catbase/config" ) -const ( - NO_BABBLER = "babbler not found" - SAID_NOTHING = "hasn't said anything yet" +var ( + NO_BABBLER = errors.New("babbler not found") + SAID_NOTHING = errors.New("hasn't said anything yet") ) @@ -72,108 +72,26 @@ func (p *BabblerPlugin) Message(message msg.Message) bool { tokens := strings.Fields(lowercase) numTokens := len(tokens) + saidSomething := false + saidWhat := "" + if numTokens >= 2 && tokens[1] == "says" { - who := tokens[0] - _, err := p.getBabbler(who) - - if err != nil { - if err.Error() == NO_BABBLER { - p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s babbler not found.", who)) - return true - } - return false - } - - var saying string - if len(tokens) == 2 { - saying, err = p.babble(who) - } else { - saying, err = p.babbleSeed(who, tokens[2:]) - } - - if err != nil { - if err.Error() == SAID_NOTHING { - p.Bot.SendMessage(message.Channel, fmt.Sprintf("%s hasn't said anything yet.", who)) - return true - } - return false - } - if saying != "" { - p.Bot.SendMessage(message.Channel, saying) - return true - } + saidWhat, saidSomething = p.getBabble(tokens) } else if len(tokens) == 4 && strings.Index(lowercase, "initialize babbler for ") == 0 { - who := tokens[3] - _, err := p.getOrCreateBabbler(who) - if err != nil { - p.Bot.SendMessage(message.Channel, "babbler initialization failed.") - return true - } - p.Bot.SendMessage(message.Channel, "okay.") - return true + saidWhat, saidSomething = p.initializeBabbler(tokens) } else if strings.Index(lowercase, "batch learn for ") == 0 { - who := tokens[3] - babblerId, err := p.getOrCreateBabbler(who) - if err != nil { - p.Bot.SendMessage(message.Channel, "batch learn failed.") - return true - } - - body := strings.Join(tokens[4:], " ") - body = strings.ToLower(body) - - for _, a := range strings.Split(body, ".") { - for _, b := range strings.Split(a, "!") { - for _, c := range strings.Split(b, "?") { - for _, d := range strings.Split(c, "\n") { - trimmed := strings.TrimSpace(d) - if trimmed != "" { - p.addToMarkovChain(babblerId, trimmed) - } - } - } - } - } - - p.Bot.SendMessage(message.Channel, "phew that was tiring.") - return true + saidWhat, saidSomething = p.batchLearn(tokens) } else if len(tokens) == 5 && strings.Index(lowercase, "merge babbler") == 0 { - if tokens[3] != "into" { - p.Bot.SendMessage(message.Channel, "try using 'merge babbler [x] into [y]'") - return true - } - - who := tokens[2] - into := tokens[4] - - if who == into { - p.Bot.SendMessage(message.Channel, "that's annoying. stop it.") - return true - } - - whoBabbler, err := p.getBabbler(who) - if err != nil { - p.Bot.SendMessage(message.Channel, "merge failed.") - return true - } - intoBabbler, err := p.getOrCreateBabbler(into) - if err != nil { - p.Bot.SendMessage(message.Channel, "merge failed.") - return true - } - - p.merge(intoBabbler, whoBabbler, into, who) - - p.Bot.SendMessage(message.Channel, "mooooiggged") - return true + saidWhat, saidSomething = p.merge(tokens) } else { - babblerId, err := p.getOrCreateBabbler(message.User.Name) - if err == nil { - p.addToMarkovChain(babblerId, lowercase) - } + //this should always return "", false + saidWhat, saidSomething = p.addToBabbler(message.User.Name, lowercase) } - return false + if saidSomething { + p.Bot.SendMessage(message.Channel, saidWhat) + } + return saidSomething } func (p *BabblerPlugin) Help(channel string, parts []string) { @@ -205,7 +123,7 @@ func (p *BabblerPlugin) getBabbler(babbler string) (int64, error) { id := int64(-1) err := p.db.Get(&id, `select id from babblers where babbler = ?`, babbler) if err != nil && err == sql.ErrNoRows { - return -1, errors.New(NO_BABBLER) + return -1, NO_BABBLER } return id, err } @@ -315,6 +233,10 @@ func (p *BabblerPlugin) incrementWordArcHelper(babblerId, fromWordId int64, toWo func (p *BabblerPlugin) addToMarkovChain(babblerId int64, phrase string) { words := strings.Fields(strings.ToLower(phrase)) + if len(words) <= 0 { + return + } + id, err := p.incrementRootWordFrequency(babblerId, words[0]) if err != nil { return @@ -357,7 +279,7 @@ func (p *BabblerPlugin) getWeightedRootWord(babblerId int64) (int64, string, err } if total == 0 { - return -1, "", errors.New(SAID_NOTHING) + return -1, "", SAID_NOTHING } which := rand.Int63n(total) @@ -473,7 +395,7 @@ func (p *BabblerPlugin) babbleSeed(babbler string, seed []string) (string, error return strings.TrimSpace(strings.Join(words, " ")), nil } -func (p *BabblerPlugin) merge(intoId, otherId int64, intoName, otherName string) error { +func (p *BabblerPlugin) mergeBabblers(intoId, otherId int64, intoName, otherName string) error { intoString := "<" + intoName + ">" otherString := "<" + otherName + ">" @@ -540,8 +462,6 @@ func (p *BabblerPlugin) merge(intoId, otherId int64, intoName, otherName string) wordId, _ = res.LastInsertId() } - log.Printf("%s %d -> %d\n", word.Word, word.Id, wordId) - mapping[word.Id] = wordId } diff --git a/plugins/babbler/commands.go b/plugins/babbler/commands.go new file mode 100644 index 0000000..c52210a --- /dev/null +++ b/plugins/babbler/commands.go @@ -0,0 +1,105 @@ +package babbler + +import ( + "fmt" + "strings" +) + +func (p *BabblerPlugin) initializeBabbler(tokens []string) (string, bool) { + who := tokens[3] + _, err := p.getOrCreateBabbler(who) + if err != nil { + return "babbler initialization failed.", true + } + return "okay.", true +} + +func (p *BabblerPlugin) addToBabbler(babblerName, whatWasSaid string) (string, bool) { + babblerId, err := p.getOrCreateBabbler(babblerName) + if err == nil { + p.addToMarkovChain(babblerId, whatWasSaid) + } + return "", false +} + +func (p *BabblerPlugin) getBabble(tokens []string) (string, bool) { + who := tokens[0] + _, err := p.getBabbler(who) + + if err != nil { + if err == NO_BABBLER { + return fmt.Sprintf("%s babbler not found.", who), true + } + } else { + + var saying string + if len(tokens) == 2 { + saying, err = p.babble(who) + } else { + saying, err = p.babbleSeed(who, tokens[2:]) + } + + if err != nil { + if err == SAID_NOTHING { + return fmt.Sprintf("%s hasn't said anything yet.", who), true + } + } else if saying != "" { + return saying, true + } + } + return "", false +} + +func (p *BabblerPlugin) batchLearn(tokens []string) (string, bool) { + who := tokens[3] + babblerId, err := p.getOrCreateBabbler(who) + if err != nil { + return "batch learn failed.", true + } + + body := strings.Join(tokens[4:], " ") + body = strings.ToLower(body) + + for _, a := range strings.Split(body, ".") { + for _, b := range strings.Split(a, "!") { + for _, c := range strings.Split(b, "?") { + for _, d := range strings.Split(c, "\n") { + trimmed := strings.TrimSpace(d) + if trimmed != "" { + p.addToMarkovChain(babblerId, trimmed) + } + } + } + } + } + return "phew that was tiring.", true +} + +func (p *BabblerPlugin) merge(tokens []string) (string, bool) { + if tokens[3] != "into" { + return "try using 'merge babbler [x] into [y]'", true + } + + who := tokens[2] + into := tokens[4] + + if who == into { + return "that's annoying. stop it.", true + } + + whoBabbler, err := p.getBabbler(who) + if err != nil { + if err == NO_BABBLER { + return fmt.Sprintf("%s babbler not found.", who), true + } + return "merge failed.", true + } + intoBabbler, err := p.getOrCreateBabbler(into) + if err != nil { + return "merge failed.", true + } + + p.mergeBabblers(intoBabbler, whoBabbler, into, who) + + return "mooooiggged", true +}