mirror of https://github.com/velour/catbase.git
Compare commits
3 Commits
494c9e87d6
...
5acf14b0ae
Author | SHA1 | Date |
---|---|---|
Chris Sexton | 5acf14b0ae | |
Chris Sexton | f8f18acacb | |
Chris Sexton | 1a066ce979 |
|
@ -1,8 +1,14 @@
|
|||
package tldr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/andrewstuart/openai"
|
||||
"github.com/velour/catbase/config"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/velour/catbase/bot"
|
||||
|
@ -14,7 +20,8 @@ import (
|
|||
)
|
||||
|
||||
type TLDRPlugin struct {
|
||||
bot bot.Bot
|
||||
b bot.Bot
|
||||
c *config.Config
|
||||
history []history
|
||||
index int
|
||||
lastRequest time.Time
|
||||
|
@ -22,36 +29,75 @@ type TLDRPlugin struct {
|
|||
|
||||
type history struct {
|
||||
timestamp time.Time
|
||||
channel string
|
||||
user string
|
||||
body string
|
||||
}
|
||||
|
||||
func New(b bot.Bot) *TLDRPlugin {
|
||||
plugin := &TLDRPlugin{
|
||||
bot: b,
|
||||
b: b,
|
||||
c: b.Config(),
|
||||
history: []history{},
|
||||
index: 0,
|
||||
lastRequest: time.Now().Add(-24 * time.Hour),
|
||||
}
|
||||
b.Register(plugin, bot.Message, plugin.message)
|
||||
b.Register(plugin, bot.Help, plugin.help)
|
||||
plugin.register()
|
||||
return plugin
|
||||
}
|
||||
|
||||
func (p *TLDRPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool {
|
||||
timeLimit := time.Duration(p.bot.Config().GetInt("TLDR.HourLimit", 1))
|
||||
lowercaseMessage := strings.ToLower(message.Body)
|
||||
if lowercaseMessage == "tl;dr" && p.lastRequest.After(time.Now().Add(-timeLimit*time.Hour)) {
|
||||
p.bot.Send(c, bot.Message, message.Channel, "Slow down, cowboy. Read that tiny backlog.")
|
||||
return true
|
||||
} else if lowercaseMessage == "tl;dr" {
|
||||
p.lastRequest = time.Now()
|
||||
nTopics := p.bot.Config().GetInt("TLDR.Topics", 5)
|
||||
func (p *TLDRPlugin) register() {
|
||||
p.b.RegisterTable(p, bot.HandlerTable{
|
||||
{
|
||||
Kind: bot.Message, IsCmd: true,
|
||||
Regex: regexp.MustCompile(`old tl;dr`),
|
||||
HelpText: "Get a rather inaccurate summary of the channel",
|
||||
Handler: p.tldrCmd,
|
||||
},
|
||||
{
|
||||
Kind: bot.Message, IsCmd: true,
|
||||
Regex: regexp.MustCompile(`tl;dr`),
|
||||
HelpText: "Get a summary of the channel",
|
||||
Handler: p.betterTLDR,
|
||||
},
|
||||
{
|
||||
Kind: bot.Message, IsCmd: false,
|
||||
Regex: regexp.MustCompile(`.*`),
|
||||
Handler: p.record,
|
||||
},
|
||||
})
|
||||
p.b.Register(p, bot.Help, p.help)
|
||||
}
|
||||
|
||||
stopWordSlice := p.bot.Config().GetArray("TLDR.StopWords", []string{})
|
||||
func (p *TLDRPlugin) tldrCmd(r bot.Request) bool {
|
||||
timeLimit := time.Duration(p.b.Config().GetInt("TLDR.HourLimit", 1))
|
||||
if p.lastRequest.After(time.Now().Add(-timeLimit * time.Hour)) {
|
||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Slow down, cowboy. Read that tiny backlog.")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *TLDRPlugin) record(r bot.Request) bool {
|
||||
hist := history{
|
||||
body: strings.ToLower(r.Msg.Body),
|
||||
user: r.Msg.User.Name,
|
||||
channel: r.Msg.Channel,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
p.addHistory(hist)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
|
||||
p.lastRequest = time.Now()
|
||||
nTopics := p.b.Config().GetInt("TLDR.Topics", 5)
|
||||
|
||||
stopWordSlice := p.b.Config().GetArray("TLDR.StopWords", []string{})
|
||||
if len(stopWordSlice) == 0 {
|
||||
stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR
|
||||
p.bot.Config().SetArray("TLDR.StopWords", stopWordSlice)
|
||||
p.b.Config().SetArray("TLDR.StopWords", stopWordSlice)
|
||||
}
|
||||
|
||||
vectoriser := nlp.NewCountVectoriser(stopWordSlice...)
|
||||
|
@ -67,7 +113,7 @@ func (p *TLDRPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message
|
|||
bestScores := make([][]float64, nTopics)
|
||||
bestDocs := make([][]history, nTopics)
|
||||
|
||||
supportingDocs := p.bot.Config().GetInt("TLDR.Support", 3)
|
||||
supportingDocs := p.b.Config().GetInt("TLDR.Support", 3)
|
||||
for i := 0; i < nTopics; i++ {
|
||||
bestScores[i] = make([]float64, supportingDocs)
|
||||
bestDocs[i] = make([]history, supportingDocs)
|
||||
|
@ -114,26 +160,16 @@ func (p *TLDRPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message
|
|||
|
||||
}
|
||||
|
||||
p.bot.Send(c, bot.Message, message.Channel, response)
|
||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, response)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
hist := history{
|
||||
body: lowercaseMessage,
|
||||
user: message.User.Name,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
p.addHistory(hist)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *TLDRPlugin) addHistory(hist history) {
|
||||
p.history = append(p.history, hist)
|
||||
sz := len(p.history)
|
||||
max := p.bot.Config().GetInt("TLDR.HistorySize", 1000)
|
||||
keepHrs := time.Duration(p.bot.Config().GetInt("TLDR.KeepHours", 24))
|
||||
max := p.b.Config().GetInt("TLDR.HistorySize", 1000)
|
||||
keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24))
|
||||
// Clamp the size of the history
|
||||
if sz > max {
|
||||
p.history = p.history[len(p.history)-max:]
|
||||
|
@ -163,7 +199,7 @@ func (p *TLDRPlugin) getTopics() []string {
|
|||
|
||||
// Help responds to help requests. Every plugin must implement a help function.
|
||||
func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool {
|
||||
p.bot.Send(c, bot.Message, message.Channel, "tl;dr")
|
||||
p.b.Send(c, bot.Message, message.Channel, "tl;dr")
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -178,3 +214,43 @@ func min(slice []float64) (float64, int) {
|
|||
}
|
||||
return minVal, minIndex
|
||||
}
|
||||
|
||||
func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
|
||||
c, err := p.getClient()
|
||||
if err != nil {
|
||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client")
|
||||
return true
|
||||
}
|
||||
promptConfig := p.c.Get("tldr.prompttemplate", "Summarize the following conversation:\n")
|
||||
promptTpl := template.Must(template.New("gptprompt").Parse(promptConfig))
|
||||
prompt := bytes.Buffer{}
|
||||
data := p.c.GetMap("tldr.promptdata", map[string]string{})
|
||||
promptTpl.Execute(&prompt, data)
|
||||
backlog := ""
|
||||
for _, h := range p.history {
|
||||
if h.channel == r.Msg.Channel {
|
||||
backlog += fmt.Sprintf("%s: %s\n", h.user, h.body)
|
||||
}
|
||||
}
|
||||
sess := c.NewChatSession(prompt.String())
|
||||
completion, err := sess.Complete(context.TODO(), backlog)
|
||||
if err != nil {
|
||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't run the OpenAI request")
|
||||
return true
|
||||
}
|
||||
log.Debug().
|
||||
Str("prompt", prompt.String()).
|
||||
Str("backlog", backlog).
|
||||
Str("completion", completion).
|
||||
Msgf("tl;dr")
|
||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, completion)
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *TLDRPlugin) getClient() (*openai.Client, error) {
|
||||
token := p.c.Get("gpt.token", "")
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("no GPT token given")
|
||||
}
|
||||
return openai.NewClient(token)
|
||||
}
|
||||
|
|
|
@ -20,20 +20,25 @@ func init() {
|
|||
log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
||||
}
|
||||
|
||||
func makeMessageBy(payload, by string) (bot.Connector, bot.Kind, msg.Message) {
|
||||
func makeMessageBy(payload, by string) bot.Request {
|
||||
isCmd := strings.HasPrefix(payload, "!")
|
||||
if isCmd {
|
||||
payload = payload[1:]
|
||||
}
|
||||
return &cli.CliPlugin{}, bot.Message, msg.Message{
|
||||
|
||||
return bot.Request{
|
||||
Conn: &cli.CliPlugin{},
|
||||
Kind: bot.Message,
|
||||
Msg: msg.Message{
|
||||
User: &user.User{Name: by},
|
||||
Channel: "test",
|
||||
Body: payload,
|
||||
Command: isCmd,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeMessage(payload string) (bot.Connector, bot.Kind, msg.Message) {
|
||||
func makeMessage(payload string) bot.Request {
|
||||
return makeMessageBy(payload, "tester")
|
||||
}
|
||||
|
||||
|
@ -43,51 +48,12 @@ func setup(t *testing.T) (*TLDRPlugin, *bot.MockBot) {
|
|||
return r, mb
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
c, mb := setup(t)
|
||||
res := c.message(makeMessage("The quick brown fox jumped over the lazy dog"))
|
||||
res = c.message(makeMessage("The cow jumped over the moon"))
|
||||
res = c.message(makeMessage("The little dog laughed to see such fun"))
|
||||
res = c.message(makeMessage("tl;dr"))
|
||||
assert.True(t, res)
|
||||
assert.Len(t, mb.Messages, 1)
|
||||
}
|
||||
|
||||
func TestDoubleUp(t *testing.T) {
|
||||
c, mb := setup(t)
|
||||
res := c.message(makeMessage("The quick brown fox jumped over the lazy dog"))
|
||||
res = c.message(makeMessage("The cow jumped over the moon"))
|
||||
res = c.message(makeMessage("The little dog laughed to see such fun"))
|
||||
res = c.message(makeMessage("tl;dr"))
|
||||
res = c.message(makeMessage("tl;dr"))
|
||||
assert.True(t, res)
|
||||
assert.Len(t, mb.Messages, 2)
|
||||
assert.Contains(t, mb.Messages[1], "Slow down, cowboy.")
|
||||
}
|
||||
|
||||
func TestAddHistoryLimitsMessages(t *testing.T) {
|
||||
c, _ := setup(t)
|
||||
max := 1000
|
||||
c.bot.Config().Set("TLDR.HistorySize", strconv.Itoa(max))
|
||||
c.bot.Config().Set("TLDR.KeepHours", "24")
|
||||
t0 := time.Now().Add(-24 * time.Hour)
|
||||
for i := 0; i < max*2; i++ {
|
||||
hist := history{
|
||||
body: "test",
|
||||
user: "tester",
|
||||
timestamp: t0.Add(time.Duration(i) * time.Second),
|
||||
}
|
||||
c.addHistory(hist)
|
||||
}
|
||||
assert.Len(t, c.history, max)
|
||||
}
|
||||
|
||||
func TestAddHistoryLimitsDays(t *testing.T) {
|
||||
c, _ := setup(t)
|
||||
hrs := 24
|
||||
expected := 24
|
||||
c.bot.Config().Set("TLDR.HistorySize", "100")
|
||||
c.bot.Config().Set("TLDR.KeepHours", strconv.Itoa(hrs))
|
||||
c.b.Config().Set("TLDR.HistorySize", "100")
|
||||
c.b.Config().Set("TLDR.KeepHours", strconv.Itoa(hrs))
|
||||
t0 := time.Now().Add(-time.Duration(hrs*2) * time.Hour)
|
||||
for i := 0; i < 48; i++ {
|
||||
hist := history{
|
||||
|
|
Loading…
Reference in New Issue