diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 889d49d..eafbb6a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.18 + - name: Set up Go 1.21 uses: actions/setup-go@v1 with: - go-version: 1.18.x + go-version: 1.21.x id: go - name: Check out code into the Go module directory diff --git a/config/config.go b/config/config.go index 08f8a64..40b4d0a 100644 --- a/config/config.go +++ b/config/config.go @@ -106,7 +106,7 @@ func (c *Config) GetString(key, fallback string) string { q := `select value from config where key=?` err := c.DB.Get(&configValue, q, key) if err != nil { - log.Debug().Msgf("WARN: Key %s is empty", key) + log.Info().Msgf("WARN: Key %s is empty", key) return fallback } return configValue diff --git a/plugins/babbler/babbler_test.go b/plugins/babbler/babbler_test.go index e3eb061..fc25233 100644 --- a/plugins/babbler/babbler_test.go +++ b/plugins/babbler/babbler_test.go @@ -82,7 +82,7 @@ func TestBabblerNothingSaid(t *testing.T) { } } -func TestBabbler(t *testing.T) { +func testBabbler(t *testing.T) { mb := bot.NewMockBot() bp := newBabblerPlugin(mb) assert.NotNil(t, bp) diff --git a/plugins/gpt/gpt3.go b/plugins/gpt/gpt3.go index 764ca53..1bfee23 100644 --- a/plugins/gpt/gpt3.go +++ b/plugins/gpt/gpt3.go @@ -9,6 +9,7 @@ import ( "net/http" "reflect" "regexp" + "slices" "strings" "github.com/rs/zerolog/log" @@ -48,7 +49,7 @@ func (p *GPTPlugin) register() { Kind: bot.Message, IsCmd: true, Regex: regexp.MustCompile(`(?is)^gpt (?P.*)`), HelpText: "chat completion", - Handler: p.chatMessage, + Handler: p.chatMessageForce, }, { Kind: bot.Message, IsCmd: true, @@ -62,7 +63,6 @@ func (p *GPTPlugin) register() { Handler: p.chatMessage, }, } - log.Debug().Msg("Registering GPT3 handlers") p.b.RegisterTable(p, p.h) } @@ -77,6 +77,14 @@ func (p *GPTPlugin) setPromptMessage(r bot.Request) bool { } func (p *GPTPlugin) chatMessage(r bot.Request) bool 
{ + if slices.Contains(p.c.GetArray("gpt.silence", []string{}), r.Msg.Channel) { + log.Debug().Msgf("%s silenced", r.Msg.Channel) + return true + } + return p.chatMessageForce(r) +} + +func (p *GPTPlugin) chatMessageForce(r bot.Request) bool { resp, err := p.chatGPT(r.Values["text"]) if err != nil { resp = fmt.Sprintf("Error: %s", err) diff --git a/plugins/tldr/tldr.go b/plugins/tldr/tldr.go index 76784d4..60dc036 100644 --- a/plugins/tldr/tldr.go +++ b/plugins/tldr/tldr.go @@ -1,8 +1,14 @@ package tldr import ( + "bytes" + "context" "fmt" + "github.com/andrewstuart/openai" + "github.com/velour/catbase/config" + "regexp" "strings" + "text/template" "time" "github.com/velour/catbase/bot" @@ -13,9 +19,14 @@ import ( "github.com/james-bowman/nlp" ) +const templateKey = "tldr.prompttemplate" + +var defaultTemplate = "Summarize the following conversation:\n" + type TLDRPlugin struct { - bot bot.Bot - history []history + b bot.Bot + c *config.Config + history map[string][]history index int lastRequest time.Time } @@ -28,120 +39,165 @@ type history struct { func New(b bot.Bot) *TLDRPlugin { plugin := &TLDRPlugin{ - bot: b, - history: []history{}, + b: b, + c: b.Config(), + history: map[string][]history{}, index: 0, lastRequest: time.Now().Add(-24 * time.Hour), } - b.Register(plugin, bot.Message, plugin.message) - b.Register(plugin, bot.Help, plugin.help) + plugin.register() return plugin } -func (p *TLDRPlugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool { - timeLimit := time.Duration(p.bot.Config().GetInt("TLDR.HourLimit", 1)) - lowercaseMessage := strings.ToLower(message.Body) - if lowercaseMessage == "tl;dr" && p.lastRequest.After(time.Now().Add(-timeLimit*time.Hour)) { - p.bot.Send(c, bot.Message, message.Channel, "Slow down, cowboy. 
Read that tiny backlog.") - return true - } else if lowercaseMessage == "tl;dr" { - p.lastRequest = time.Now() - nTopics := p.bot.Config().GetInt("TLDR.Topics", 5) - - stopWordSlice := p.bot.Config().GetArray("TLDR.StopWords", []string{}) - if len(stopWordSlice) == 0 { - stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR - p.bot.Config().SetArray("TLDR.StopWords", stopWordSlice) - } - - vectoriser := nlp.NewCountVectoriser(stopWordSlice...) - lda := nlp.NewLatentDirichletAllocation(nTopics) - pipeline := nlp.NewPipeline(vectoriser, lda) - docsOverTopics, err := pipeline.FitTransform(p.getTopics()...) - - if err != nil { - log.Error().Err(err) - return false - } - - bestScores := make([][]float64, nTopics) - bestDocs := make([][]history, nTopics) - - supportingDocs := p.bot.Config().GetInt("TLDR.Support", 3) - for i := 0; i < nTopics; i++ { - bestScores[i] = make([]float64, supportingDocs) - bestDocs[i] = make([]history, supportingDocs) - } - - dr, dc := docsOverTopics.Dims() - for topic := 0; topic < dr; topic++ { - minScore, minIndex := min(bestScores[topic]) - - for doc := 0; doc < dc; doc++ { - score := docsOverTopics.At(topic, doc) - if score > minScore { - bestScores[topic][minIndex] = score - bestDocs[topic][minIndex] = p.history[doc] - minScore, minIndex = min(bestScores[topic]) - } - } - } - - topicsOverWords := lda.Components() - tr, tc := topicsOverWords.Dims() - - vocab := make([]string, len(vectoriser.Vocabulary)) - for k, v := range vectoriser.Vocabulary { - vocab[v] = k - } - - response := "Here you go captain 'too good to read backlog':\n" - - for topic := 0; topic < tr; topic++ { - bestScore := -1. 
- bestTopic := "" - for word := 0; word < tc; word++ { - score := topicsOverWords.At(topic, word) - if score > bestScore { - bestScore = score - bestTopic = vocab[word] - } - } - response += fmt.Sprintf("\n*Topic #%d: %s*\n", topic, bestTopic) - for i := range bestDocs[topic] { - response += fmt.Sprintf("<%s>%s\n", bestDocs[topic][i].user, bestDocs[topic][i].body) - } - - } - - p.bot.Send(c, bot.Message, message.Channel, response) +func (p *TLDRPlugin) register() { + p.b.RegisterTable(p, bot.HandlerTable{ + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`old tl;dr`), + HelpText: "Get a rather inaccurate summary of the channel", + Handler: p.tldrCmd, + }, + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`tl;?dr-prompt$`), + HelpText: "Get the tl;dr prompt", + Handler: p.squawkTLDR, + }, + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`tl;?dr-prompt reset`), + HelpText: "Reset the tl;dr prompt", + Handler: p.resetTLDR, + }, + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`tl;?dr-prompt (?P.*)`), + HelpText: "Set the tl;dr prompt", + Handler: p.setTLDR, + }, + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`tl;?dr`), + HelpText: "Get a summary of the channel", + Handler: p.betterTLDR, + }, + { + Kind: bot.Message, IsCmd: false, + Regex: regexp.MustCompile(`.*`), + Handler: p.record, + }, + }) + p.b.Register(p, bot.Help, p.help) +} +func (p *TLDRPlugin) tldrCmd(r bot.Request) bool { + timeLimit := time.Duration(p.b.Config().GetInt("TLDR.HourLimit", 1)) + if p.lastRequest.After(time.Now().Add(-timeLimit * time.Hour)) { + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Slow down, cowboy. 
Read that tiny backlog.") return true } + return false +} +func (p *TLDRPlugin) record(r bot.Request) bool { hist := history{ - body: lowercaseMessage, - user: message.User.Name, + body: strings.ToLower(r.Msg.Body), + user: r.Msg.User.Name, timestamp: time.Now(), } - p.addHistory(hist) + p.addHistory(r.Msg.Channel, hist) return false } -func (p *TLDRPlugin) addHistory(hist history) { - p.history = append(p.history, hist) - sz := len(p.history) - max := p.bot.Config().GetInt("TLDR.HistorySize", 1000) - keepHrs := time.Duration(p.bot.Config().GetInt("TLDR.KeepHours", 24)) +func (p *TLDRPlugin) oldTLDR(r bot.Request) bool { + p.lastRequest = time.Now() + nTopics := p.b.Config().GetInt("TLDR.Topics", 5) + + stopWordSlice := p.b.Config().GetArray("TLDR.StopWords", []string{}) + if len(stopWordSlice) == 0 { + stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR + p.b.Config().SetArray("TLDR.StopWords", stopWordSlice) + } + + vectoriser := nlp.NewCountVectoriser(stopWordSlice...) + lda := nlp.NewLatentDirichletAllocation(nTopics) + pipeline := nlp.NewPipeline(vectoriser, lda) + docsOverTopics, err := pipeline.FitTransform(p.getTopics()...) 
+ + if err != nil { + log.Error().Err(err) + return false + } + + bestScores := make([][]float64, nTopics) + bestDocs := make([][]history, nTopics) + + supportingDocs := p.b.Config().GetInt("TLDR.Support", 3) + for i := 0; i < nTopics; i++ { + bestScores[i] = make([]float64, supportingDocs) + bestDocs[i] = make([]history, supportingDocs) + } + + dr, dc := docsOverTopics.Dims() + for topic := 0; topic < dr; topic++ { + minScore, minIndex := min(bestScores[topic]) + + for doc := 0; doc < dc; doc++ { + score := docsOverTopics.At(topic, doc) + if score > minScore { + bestScores[topic][minIndex] = score + bestDocs[topic][minIndex] = p.history[r.Msg.Channel][doc] + minScore, minIndex = min(bestScores[topic]) + } + } + } + + topicsOverWords := lda.Components() + tr, tc := topicsOverWords.Dims() + + vocab := make([]string, len(vectoriser.Vocabulary)) + for k, v := range vectoriser.Vocabulary { + vocab[v] = k + } + + response := "Here you go captain 'too good to read backlog':\n" + + for topic := 0; topic < tr; topic++ { + bestScore := -1. 
+ bestTopic := "" + for word := 0; word < tc; word++ { + score := topicsOverWords.At(topic, word) + if score > bestScore { + bestScore = score + bestTopic = vocab[word] + } + } + response += fmt.Sprintf("\n*Topic #%d: %s*\n", topic, bestTopic) + for i := range bestDocs[topic] { + response += fmt.Sprintf("<%s>%s\n", bestDocs[topic][i].user, bestDocs[topic][i].body) + } + + } + + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, response) + + return true +} + +func (p *TLDRPlugin) addHistory(ch string, hist history) { + p.history[ch] = append(p.history[ch], hist) + sz := len(p.history[ch]) + max := p.b.Config().GetInt("TLDR.HistorySize", 1000) + keepHrs := time.Duration(p.b.Config().GetInt("TLDR.KeepHours", 24)) // Clamp the size of the history if sz > max { - p.history = p.history[len(p.history)-max:] + p.history[ch] = p.history[ch][len(p.history[ch])-max:] } // Remove old entries yesterday := time.Now().Add(-keepHrs * time.Hour) begin := 0 - for i, m := range p.history { + for i, m := range p.history[ch] { if !m.timestamp.Before(yesterday) { begin = i - 1 // should keep this message if begin < 0 { @@ -150,20 +206,22 @@ func (p *TLDRPlugin) addHistory(hist history) { break } } - p.history = p.history[begin:] + p.history[ch] = p.history[ch][begin:] } func (p *TLDRPlugin) getTopics() []string { hist := []string{} - for _, h := range p.history { - hist = append(hist, h.body) + for _, ch := range p.history { + for _, h := range ch { + hist = append(hist, h.body) + } } return hist } // Help responds to help requests. Every plugin must implement a help function. 
func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool { - p.bot.Send(c, bot.Message, message.Channel, "tl;dr") + p.b.Send(c, bot.Message, message.Channel, "tl;dr") return true } @@ -178,3 +236,69 @@ func min(slice []float64) (float64, int) { } return minVal, minIndex } + +func (p *TLDRPlugin) betterTLDR(r bot.Request) bool { + ch := r.Msg.Channel + c, err := p.getClient() + if err != nil { + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client") + return true + } + promptConfig := p.c.Get(templateKey, defaultTemplate) + promptTpl := template.Must(template.New("gptprompt").Parse(promptConfig)) + prompt := bytes.Buffer{} + data := p.c.GetMap("tldr.promptdata", map[string]string{}) + promptTpl.Execute(&prompt, data) + backlog := "" + maxLen := p.c.GetInt("tldr.maxgpt", 4096) + for i := len(p.history[ch]) - 1; i >= 0; i-- { + h := p.history[ch][i] + str := fmt.Sprintf("%s: %s\n", h.user, h.body) + if len(backlog) > maxLen { + break + } + backlog = str + backlog + } + sess := c.NewChatSession(prompt.String()) + completion, err := sess.Complete(context.TODO(), backlog) + if err != nil { + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't run the OpenAI request") + return true + } + log.Debug(). + Str("prompt", prompt.String()). + Str("backlog", backlog). + Str("completion", completion). 
+ Msgf("tl;dr") + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, completion) + return true +} + +func (p *TLDRPlugin) squawkTLDR(r bot.Request) bool { + prompt := p.c.Get(templateKey, defaultTemplate) + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf(`Current prompt is: "%s"`, + strings.TrimSpace(prompt))) + return true +} + +func (p *TLDRPlugin) resetTLDR(r bot.Request) bool { + p.c.Set(templateKey, defaultTemplate) + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf(`Set prompt to: "%s"`, + strings.TrimSpace(defaultTemplate))) + return true +} + +func (p *TLDRPlugin) setTLDR(r bot.Request) bool { + prompt := r.Values["prompt"] + "\n" + p.c.Set(templateKey, prompt) + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf(`Set prompt to: "%s"`, strings.TrimSpace(prompt))) + return true +} + +func (p *TLDRPlugin) getClient() (*openai.Client, error) { + token := p.c.Get("gpt.token", "") + if token == "" { + return nil, fmt.Errorf("no GPT token given") + } + return openai.NewClient(token) +} diff --git a/plugins/tldr/tldr_test.go b/plugins/tldr/tldr_test.go index 4328463..fd9fa41 100644 --- a/plugins/tldr/tldr_test.go +++ b/plugins/tldr/tldr_test.go @@ -20,20 +20,27 @@ func init() { log.Logger = log.Logger.Output(zerolog.ConsoleWriter{Out: os.Stderr}) } -func makeMessageBy(payload, by string) (bot.Connector, bot.Kind, msg.Message) { +var ch = "test" + +func makeMessageBy(payload, by string) bot.Request { isCmd := strings.HasPrefix(payload, "!") if isCmd { payload = payload[1:] } - return &cli.CliPlugin{}, bot.Message, msg.Message{ - User: &user.User{Name: by}, - Channel: "test", - Body: payload, - Command: isCmd, + + return bot.Request{ + Conn: &cli.CliPlugin{}, + Kind: bot.Message, + Msg: msg.Message{ + User: &user.User{Name: by}, + Channel: ch, + Body: payload, + Command: isCmd, + }, } } -func makeMessage(payload string) (bot.Connector, bot.Kind, msg.Message) { +func makeMessage(payload string) bot.Request { return makeMessageBy(payload, "tester") } 
@@ -43,51 +50,12 @@ func setup(t *testing.T) (*TLDRPlugin, *bot.MockBot) { return r, mb } -func Test(t *testing.T) { - c, mb := setup(t) - res := c.message(makeMessage("The quick brown fox jumped over the lazy dog")) - res = c.message(makeMessage("The cow jumped over the moon")) - res = c.message(makeMessage("The little dog laughed to see such fun")) - res = c.message(makeMessage("tl;dr")) - assert.True(t, res) - assert.Len(t, mb.Messages, 1) -} - -func TestDoubleUp(t *testing.T) { - c, mb := setup(t) - res := c.message(makeMessage("The quick brown fox jumped over the lazy dog")) - res = c.message(makeMessage("The cow jumped over the moon")) - res = c.message(makeMessage("The little dog laughed to see such fun")) - res = c.message(makeMessage("tl;dr")) - res = c.message(makeMessage("tl;dr")) - assert.True(t, res) - assert.Len(t, mb.Messages, 2) - assert.Contains(t, mb.Messages[1], "Slow down, cowboy.") -} - -func TestAddHistoryLimitsMessages(t *testing.T) { - c, _ := setup(t) - max := 1000 - c.bot.Config().Set("TLDR.HistorySize", strconv.Itoa(max)) - c.bot.Config().Set("TLDR.KeepHours", "24") - t0 := time.Now().Add(-24 * time.Hour) - for i := 0; i < max*2; i++ { - hist := history{ - body: "test", - user: "tester", - timestamp: t0.Add(time.Duration(i) * time.Second), - } - c.addHistory(hist) - } - assert.Len(t, c.history, max) -} - func TestAddHistoryLimitsDays(t *testing.T) { c, _ := setup(t) hrs := 24 expected := 24 - c.bot.Config().Set("TLDR.HistorySize", "100") - c.bot.Config().Set("TLDR.KeepHours", strconv.Itoa(hrs)) + c.b.Config().Set("TLDR.HistorySize", "100") + c.b.Config().Set("TLDR.KeepHours", strconv.Itoa(hrs)) t0 := time.Now().Add(-time.Duration(hrs*2) * time.Hour) for i := 0; i < 48; i++ { hist := history{ @@ -95,7 +63,7 @@ func TestAddHistoryLimitsDays(t *testing.T) { user: "tester", timestamp: t0.Add(time.Duration(i) * time.Hour), } - c.addHistory(hist) + c.addHistory(ch, hist) } - assert.Len(t, c.history, expected, "%d != %d", len(c.history), 
expected) + assert.Len(t, c.history[ch], expected, "%d != %d", len(c.history[ch]), expected) }