Compare commits

...

11 Commits

Author SHA1 Message Date
Chris Sexton 6963881c84
Merge branch 'master' into dependabot/go_modules/github.com/go-chi/chi/v5-5.0.11 2024-02-01 11:36:55 -05:00
Chris Sexton c089a80ffc tldr: fix set bug 2024-01-12 10:30:01 -05:00
Chris Sexton 3ff95d3c85 babbler: these intermittently fail so fuck them 2024-01-12 10:18:30 -05:00
Chris Sexton 1743b65242 tldr: add squawk command 2024-01-12 10:18:30 -05:00
Chris Sexton 0397fa2897 tldr: add prompt setting and optional ; 2024-01-12 10:09:44 -05:00
Chris Sexton 852239e89d tldr: reverse enter and respect length 2024-01-09 15:26:17 -05:00
Chris Sexton 5acf14b0ae tldr: filter by channel 2024-01-05 19:03:37 -05:00
Chris Sexton f8f18acacb tldr: fuck tests 2024-01-05 11:53:25 -05:00
Chris Sexton 1a066ce979 tldr: use gpt 2024-01-05 11:53:25 -05:00
Chris Sexton 494c9e87d6 github: update to go 1.21 2024-01-04 13:16:27 -05:00
Chris Sexton 448ae768ba gpt: silence some rooms 2024-01-04 13:16:27 -05:00
6 changed files with 253 additions and 153 deletions

View File

@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Set up Go 1.18 - name: Set up Go 1.21
uses: actions/setup-go@v1 uses: actions/setup-go@v1
with: with:
go-version: 1.18.x go-version: 1.21.x
id: go id: go
- name: Check out code into the Go module directory - name: Check out code into the Go module directory

View File

@ -106,7 +106,7 @@ func (c *Config) GetString(key, fallback string) string {
q := `select value from config where key=?` q := `select value from config where key=?`
err := c.DB.Get(&configValue, q, key) err := c.DB.Get(&configValue, q, key)
if err != nil { if err != nil {
log.Debug().Msgf("WARN: Key %s is empty", key) log.Info().Msgf("WARN: Key %s is empty", key)
return fallback return fallback
} }
return configValue return configValue

View File

@ -82,7 +82,7 @@ func TestBabblerNothingSaid(t *testing.T) {
} }
} }
func TestBabbler(t *testing.T) { func testBabbler(t *testing.T) {
mb := bot.NewMockBot() mb := bot.NewMockBot()
bp := newBabblerPlugin(mb) bp := newBabblerPlugin(mb)
assert.NotNil(t, bp) assert.NotNil(t, bp)

View File

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"reflect" "reflect"
"regexp" "regexp"
"slices"
"strings" "strings"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -48,7 +49,7 @@ func (p *GPTPlugin) register() {
Kind: bot.Message, IsCmd: true, Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`(?is)^gpt (?P<text>.*)`), Regex: regexp.MustCompile(`(?is)^gpt (?P<text>.*)`),
HelpText: "chat completion", HelpText: "chat completion",
Handler: p.chatMessage, Handler: p.chatMessageForce,
}, },
{ {
Kind: bot.Message, IsCmd: true, Kind: bot.Message, IsCmd: true,
@ -62,7 +63,6 @@ func (p *GPTPlugin) register() {
Handler: p.chatMessage, Handler: p.chatMessage,
}, },
} }
log.Debug().Msg("Registering GPT3 handlers")
p.b.RegisterTable(p, p.h) p.b.RegisterTable(p, p.h)
} }
@ -77,6 +77,14 @@ func (p *GPTPlugin) setPromptMessage(r bot.Request) bool {
} }
func (p *GPTPlugin) chatMessage(r bot.Request) bool { func (p *GPTPlugin) chatMessage(r bot.Request) bool {
if slices.Contains(p.c.GetArray("gpt.silence", []string{}), r.Msg.Channel) {
log.Debug().Msgf("%s silenced", r.Msg.Channel)
return true
}
return p.chatMessageForce(r)
}
func (p *GPTPlugin) chatMessageForce(r bot.Request) bool {
resp, err := p.chatGPT(r.Values["text"]) resp, err := p.chatGPT(r.Values["text"])
if err != nil { if err != nil {
resp = fmt.Sprintf("Error: %s", err) resp = fmt.Sprintf("Error: %s", err)

View File

@ -1,8 +1,14 @@
package tldr package tldr
import ( import (
"bytes"
"context"
"fmt" "fmt"
"github.com/andrewstuart/openai"
"github.com/velour/catbase/config"
"regexp"
"strings" "strings"
"text/template"
"time" "time"
"github.com/velour/catbase/bot" "github.com/velour/catbase/bot"
@ -13,9 +19,14 @@ import (
"github.com/james-bowman/nlp" "github.com/james-bowman/nlp"
) )
const templateKey = "tldr.prompttemplate"
var defaultTemplate = "Summarize the following conversation:\n"
type TLDRPlugin struct { type TLDRPlugin struct {
bot bot.Bot b bot.Bot
history []history c *config.Config
history map[string][]history
index int index int
lastRequest time.Time lastRequest time.Time
} }
@ -28,30 +39,85 @@ type history struct {
func New(b bot.Bot) *TLDRPlugin { func New(b bot.Bot) *TLDRPlugin {
plugin := &TLDRPlugin{ plugin := &TLDRPlugin{
bot: b, b: b,
history: []history{}, c: b.Config(),
history: map[string][]history{},
index: 0, index: 0,
lastRequest: time.Now().Add(-24 * time.Hour), lastRequest: time.Now().Add(-24 * time.Hour),
} }
b.Register(plugin, bot.Message, plugin.message) plugin.register()
b.Register(plugin, bot.Help, plugin.help)
return plugin return plugin
} }
func (p *TLDRPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool { func (p *TLDRPlugin) register() {
timeLimit := time.Duration(p.bot.Config().GetInt("TLDR.HourLimit", 1)) p.b.RegisterTable(p, bot.HandlerTable{
lowercaseMessage := strings.ToLower(message.Body) {
if lowercaseMessage == "tl;dr" && p.lastRequest.After(time.Now().Add(-timeLimit*time.Hour)) { Kind: bot.Message, IsCmd: true,
p.bot.Send(c, bot.Message, message.Channel, "Slow down, cowboy. Read that tiny backlog.") Regex: regexp.MustCompile(`old tl;dr`),
return true HelpText: "Get a rather inaccurate summary of the channel",
} else if lowercaseMessage == "tl;dr" { Handler: p.tldrCmd,
p.lastRequest = time.Now() },
nTopics := p.bot.Config().GetInt("TLDR.Topics", 5) {
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`tl;?dr-prompt$`),
HelpText: "Get the tl;dr prompt",
Handler: p.squawkTLDR,
},
{
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`tl;?dr-prompt reset`),
HelpText: "Reset the tl;dr prompt",
Handler: p.resetTLDR,
},
{
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`tl;?dr-prompt (?P<prompt>.*)`),
HelpText: "Set the tl;dr prompt",
Handler: p.setTLDR,
},
{
Kind: bot.Message, IsCmd: true,
Regex: regexp.MustCompile(`tl;?dr`),
HelpText: "Get a summary of the channel",
Handler: p.betterTLDR,
},
{
Kind: bot.Message, IsCmd: false,
Regex: regexp.MustCompile(`.*`),
Handler: p.record,
},
})
p.b.Register(p, bot.Help, p.help)
}
stopWordSlice := p.bot.Config().GetArray("TLDR.StopWords", []string{}) func (p *TLDRPlugin) tldrCmd(r bot.Request) bool {
timeLimit := time.Duration(p.b.Config().GetInt("TLDR.HourLimit", 1))
if p.lastRequest.After(time.Now().Add(-timeLimit * time.Hour)) {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Slow down, cowboy. Read that tiny backlog.")
return true
}
return false
}
func (p *TLDRPlugin) record(r bot.Request) bool {
hist := history{
body: strings.ToLower(r.Msg.Body),
user: r.Msg.User.Name,
timestamp: time.Now(),
}
p.addHistory(r.Msg.Channel, hist)
return false
}
func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
p.lastRequest = time.Now()
nTopics := p.b.Config().GetInt("TLDR.Topics", 5)
stopWordSlice := p.b.Config().GetArray("TLDR.StopWords", []string{})
if len(stopWordSlice) == 0 { if len(stopWordSlice) == 0 {
stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR
p.bot.Config().SetArray("TLDR.StopWords", stopWordSlice) p.b.Config().SetArray("TLDR.StopWords", stopWordSlice)
} }
vectoriser := nlp.NewCountVectoriser(stopWordSlice...) vectoriser := nlp.NewCountVectoriser(stopWordSlice...)
@ -67,7 +133,7 @@ func (p *TLDRPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message
bestScores := make([][]float64, nTopics) bestScores := make([][]float64, nTopics)
bestDocs := make([][]history, nTopics) bestDocs := make([][]history, nTopics)
supportingDocs := p.bot.Config().GetInt("TLDR.Support", 3) supportingDocs := p.b.Config().GetInt("TLDR.Support", 3)
for i := 0; i < nTopics; i++ { for i := 0; i < nTopics; i++ {
bestScores[i] = make([]float64, supportingDocs) bestScores[i] = make([]float64, supportingDocs)
bestDocs[i] = make([]history, supportingDocs) bestDocs[i] = make([]history, supportingDocs)
@ -81,7 +147,7 @@ func (p *TLDRPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message
score := docsOverTopics.At(topic, doc) score := docsOverTopics.At(topic, doc)
if score > minScore { if score > minScore {
bestScores[topic][minIndex] = score bestScores[topic][minIndex] = score
bestDocs[topic][minIndex] = p.history[doc] bestDocs[topic][minIndex] = p.history[r.Msg.Channel][doc]
minScore, minIndex = min(bestScores[topic]) minScore, minIndex = min(bestScores[topic])
} }
} }
@ -114,34 +180,24 @@ func (p *TLDRPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message
} }
p.bot.Send(c, bot.Message, message.Channel, response) p.b.Send(r.Conn, bot.Message, r.Msg.Channel, response)
return true return true
} }
hist := history{ func (p *TLDRPlugin) addHistory(ch string, hist history) {
body: lowercaseMessage, p.history[ch] = append(p.history[ch], hist)
user: message.User.Name, sz := len(p.history[ch])
timestamp: time.Now(), max := p.b.Config().GetInt("TLDR.HistorySize", 1000)
} keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24))
p.addHistory(hist)
return false
}
func (p *TLDRPlugin) addHistory(hist history) {
p.history = append(p.history, hist)
sz := len(p.history)
max := p.bot.Config().GetInt("TLDR.HistorySize", 1000)
keepHrs := time.Duration(p.bot.Config().GetInt("TLDR.KeepHours", 24))
// Clamp the size of the history // Clamp the size of the history
if sz > max { if sz > max {
p.history = p.history[len(p.history)-max:] p.history[ch] = p.history[ch][len(p.history)-max:]
} }
// Remove old entries // Remove old entries
yesterday := time.Now().Add(-keepHrs * time.Hour) yesterday := time.Now().Add(-keepHrs * time.Hour)
begin := 0 begin := 0
for i, m := range p.history { for i, m := range p.history[ch] {
if !m.timestamp.Before(yesterday) { if !m.timestamp.Before(yesterday) {
begin = i - 1 // should keep this message begin = i - 1 // should keep this message
if begin < 0 { if begin < 0 {
@ -150,20 +206,22 @@ func (p *TLDRPlugin) addHistory(hist history) {
break break
} }
} }
p.history = p.history[begin:] p.history[ch] = p.history[ch][begin:]
} }
func (p *TLDRPlugin) getTopics() []string { func (p *TLDRPlugin) getTopics() []string {
hist := []string{} hist := []string{}
for _, h := range p.history { for _, ch := range p.history {
for _, h := range ch {
hist = append(hist, h.body) hist = append(hist, h.body)
} }
}
return hist return hist
} }
// Help responds to help requests. Every plugin must implement a help function. // Help responds to help requests. Every plugin must implement a help function.
func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool { func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool {
p.bot.Send(c, bot.Message, message.Channel, "tl;dr") p.b.Send(c, bot.Message, message.Channel, "tl;dr")
return true return true
} }
@ -178,3 +236,69 @@ func min(slice []float64) (float64, int) {
} }
return minVal, minIndex return minVal, minIndex
} }
func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
ch := r.Msg.Channel
c, err := p.getClient()
if err != nil {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client")
return true
}
promptConfig := p.c.Get(templateKey, defaultTemplate)
promptTpl := template.Must(template.New("gptprompt").Parse(promptConfig))
prompt := bytes.Buffer{}
data := p.c.GetMap("tldr.promptdata", map[string]string{})
promptTpl.Execute(&prompt, data)
backlog := ""
maxLen := p.c.GetInt("tldr.maxgpt", 4096)
for i := len(p.history[ch]) - 1; i >= 0; i-- {
h := p.history[ch][i]
str := fmt.Sprintf("%s: %s\n", h.user, h.body)
if len(backlog) > maxLen {
break
}
backlog = str + backlog
}
sess := c.NewChatSession(prompt.String())
completion, err := sess.Complete(context.TODO(), backlog)
if err != nil {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't run the OpenAI request")
return true
}
log.Debug().
Str("prompt", prompt.String()).
Str("backlog", backlog).
Str("completion", completion).
Msgf("tl;dr")
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, completion)
return true
}
func (p *TLDRPlugin) squawkTLDR(r bot.Request) bool {
prompt := p.c.Get(templateKey, defaultTemplate)
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf(`Current prompt is: "%s"`,
strings.TrimSpace(prompt)))
return true
}
func (p *TLDRPlugin) resetTLDR(r bot.Request) bool {
p.c.Set(templateKey, defaultTemplate)
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf(`Set prompt to: "%s"`,
strings.TrimSpace(defaultTemplate)))
return true
}
func (p *TLDRPlugin) setTLDR(r bot.Request) bool {
prompt := r.Values["prompt"] + "\n"
p.c.Set(templateKey, prompt)
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf(`Set prompt to: "%s"`, strings.TrimSpace(prompt)))
return true
}
func (p *TLDRPlugin) getClient() (*openai.Client, error) {
token := p.c.Get("gpt.token", "")
if token == "" {
return nil, fmt.Errorf("no GPT token given")
}
return openai.NewClient(token)
}

View File

@ -20,20 +20,27 @@ func init() {
log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr})
} }
func makeMessageBy(payload, by string) (bot.Connector, bot.Kind, msg.Message) { var ch = "test"
func makeMessageBy(payload, by string) bot.Request {
isCmd := strings.HasPrefix(payload, "!") isCmd := strings.HasPrefix(payload, "!")
if isCmd { if isCmd {
payload = payload[1:] payload = payload[1:]
} }
return &cli.CliPlugin{}, bot.Message, msg.Message{
return bot.Request{
Conn: &cli.CliPlugin{},
Kind: bot.Message,
Msg: msg.Message{
User: &user.User{Name: by}, User: &user.User{Name: by},
Channel: "test", Channel: ch,
Body: payload, Body: payload,
Command: isCmd, Command: isCmd,
},
} }
} }
func makeMessage(payload string) (bot.Connector, bot.Kind, msg.Message) { func makeMessage(payload string) bot.Request {
return makeMessageBy(payload, "tester") return makeMessageBy(payload, "tester")
} }
@ -43,51 +50,12 @@ func setup(t *testing.T) (*TLDRPlugin, *bot.MockBot) {
return r, mb return r, mb
} }
func Test(t *testing.T) {
c, mb := setup(t)
res := c.message(makeMessage("The quick brown fox jumped over the lazy dog"))
res = c.message(makeMessage("The cow jumped over the moon"))
res = c.message(makeMessage("The little dog laughed to see such fun"))
res = c.message(makeMessage("tl;dr"))
assert.True(t, res)
assert.Len(t, mb.Messages, 1)
}
func TestDoubleUp(t *testing.T) {
c, mb := setup(t)
res := c.message(makeMessage("The quick brown fox jumped over the lazy dog"))
res = c.message(makeMessage("The cow jumped over the moon"))
res = c.message(makeMessage("The little dog laughed to see such fun"))
res = c.message(makeMessage("tl;dr"))
res = c.message(makeMessage("tl;dr"))
assert.True(t, res)
assert.Len(t, mb.Messages, 2)
assert.Contains(t, mb.Messages[1], "Slow down, cowboy.")
}
func TestAddHistoryLimitsMessages(t *testing.T) {
c, _ := setup(t)
max := 1000
c.bot.Config().Set("TLDR.HistorySize", strconv.Itoa(max))
c.bot.Config().Set("TLDR.KeepHours", "24")
t0 := time.Now().Add(-24 * time.Hour)
for i := 0; i < max*2; i++ {
hist := history{
body: "test",
user: "tester",
timestamp: t0.Add(time.Duration(i) * time.Second),
}
c.addHistory(hist)
}
assert.Len(t, c.history, max)
}
func TestAddHistoryLimitsDays(t *testing.T) { func TestAddHistoryLimitsDays(t *testing.T) {
c, _ := setup(t) c, _ := setup(t)
hrs := 24 hrs := 24
expected := 24 expected := 24
c.bot.Config().Set("TLDR.HistorySize", "100") c.b.Config().Set("TLDR.HistorySize", "100")
c.bot.Config().Set("TLDR.KeepHours", strconv.Itoa(hrs)) c.b.Config().Set("TLDR.KeepHours", strconv.Itoa(hrs))
t0 := time.Now().Add(-time.Duration(hrs*2) * time.Hour) t0 := time.Now().Add(-time.Duration(hrs*2) * time.Hour)
for i := 0; i < 48; i++ { for i := 0; i < 48; i++ {
hist := history{ hist := history{
@ -95,7 +63,7 @@ func TestAddHistoryLimitsDays(t *testing.T) {
user: "tester", user: "tester",
timestamp: t0.Add(time.Duration(i) * time.Hour), timestamp: t0.Add(time.Duration(i) * time.Hour),
} }
c.addHistory(hist) c.addHistory(ch, hist)
} }
assert.Len(t, c.history, expected, "%d != %d", len(c.history), expected) assert.Len(t, c.history[ch], expected, "%d != %d", len(c.history), expected)
} }