From 40b9507d3ef021ede68291284aa26179ce40d091 Mon Sep 17 00:00:00 2001
From: Chris Sexton <3216719+chrissexton@users.noreply.github.com>
Date: Fri, 27 Sep 2024 13:24:04 -0400
Subject: [PATCH] tldr: use Gemini

---
 plugins/tldr/tldr.go | 158 ++++++++++---------------------------------
 1 file changed, 34 insertions(+), 124 deletions(-)

diff --git a/plugins/tldr/tldr.go b/plugins/tldr/tldr.go
index 60dc036..9d120cb 100644
--- a/plugins/tldr/tldr.go
+++ b/plugins/tldr/tldr.go
@@ -3,9 +3,11 @@ package tldr
 import (
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
-	"github.com/andrewstuart/openai"
+	"github.com/google/generative-ai-go/genai"
 	"github.com/velour/catbase/config"
+	"google.golang.org/api/option"
 	"regexp"
 	"strings"
 	"text/template"
@@ -15,8 +17,6 @@ import (
 	"github.com/velour/catbase/bot/msg"
 
 	"github.com/rs/zerolog/log"
-
-	"github.com/james-bowman/nlp"
 )
 
 const templateKey = "tldr.prompttemplate"
@@ -51,12 +51,6 @@ func New(b bot.Bot) *TLDRPlugin {
 
 func (p *TLDRPlugin) register() {
 	p.b.RegisterTable(p, bot.HandlerTable{
-		{
-			Kind: bot.Message, IsCmd: true,
-			Regex: regexp.MustCompile(`old tl;dr`),
-			HelpText: "Get a rather inaccurate summary of the channel",
-			Handler: p.tldrCmd,
-		},
 		{
 			Kind: bot.Message, IsCmd: true,
 			Regex: regexp.MustCompile(`tl;?dr-prompt$`),
@@ -90,15 +84,6 @@ func (p *TLDRPlugin) register() {
 	p.b.Register(p, bot.Help, p.help)
 }
 
-func (p *TLDRPlugin) tldrCmd(r bot.Request) bool {
-	timeLimit := time.Duration(p.b.Config().GetInt("TLDR.HourLimit", 1))
-	if p.lastRequest.After(time.Now().Add(-timeLimit * time.Hour)) {
-		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Slow down, cowboy. Read that tiny backlog.")
-		return true
-	}
-	return false
-}
-
 func (p *TLDRPlugin) record(r bot.Request) bool {
 	hist := history{
 		body: strings.ToLower(r.Msg.Body),
@@ -110,81 +95,6 @@ func (p *TLDRPlugin) record(r bot.Request) bool {
 	return false
 }
 
-func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
-	p.lastRequest = time.Now()
-	nTopics := p.b.Config().GetInt("TLDR.Topics", 5)
-
-	stopWordSlice := p.b.Config().GetArray("TLDR.StopWords", []string{})
-	if len(stopWordSlice) == 0 {
-		stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR
-		p.b.Config().SetArray("TLDR.StopWords", stopWordSlice)
-	}
-
-	vectoriser := nlp.NewCountVectoriser(stopWordSlice...)
-	lda := nlp.NewLatentDirichletAllocation(nTopics)
-	pipeline := nlp.NewPipeline(vectoriser, lda)
-	docsOverTopics, err := pipeline.FitTransform(p.getTopics()...)
-
-	if err != nil {
-		log.Error().Err(err)
-		return false
-	}
-
-	bestScores := make([][]float64, nTopics)
-	bestDocs := make([][]history, nTopics)
-
-	supportingDocs := p.b.Config().GetInt("TLDR.Support", 3)
-	for i := 0; i < nTopics; i++ {
-		bestScores[i] = make([]float64, supportingDocs)
-		bestDocs[i] = make([]history, supportingDocs)
-	}
-
-	dr, dc := docsOverTopics.Dims()
-	for topic := 0; topic < dr; topic++ {
-		minScore, minIndex := min(bestScores[topic])
-
-		for doc := 0; doc < dc; doc++ {
-			score := docsOverTopics.At(topic, doc)
-			if score > minScore {
-				bestScores[topic][minIndex] = score
-				bestDocs[topic][minIndex] = p.history[r.Msg.Channel][doc]
-				minScore, minIndex = min(bestScores[topic])
-			}
-		}
-	}
-
-	topicsOverWords := lda.Components()
-	tr, tc := topicsOverWords.Dims()
-
-	vocab := make([]string, len(vectoriser.Vocabulary))
-	for k, v := range vectoriser.Vocabulary {
-		vocab[v] = k
-	}
-
-	response := "Here you go captain 'too good to read backlog':\n"
-
-	for topic := 0; topic < tr; topic++ {
-		bestScore := -1.
-		bestTopic := ""
-		for word := 0; word < tc; word++ {
-			score := topicsOverWords.At(topic, word)
-			if score > bestScore {
-				bestScore = score
-				bestTopic = vocab[word]
-			}
-		}
-		response += fmt.Sprintf("\n*Topic #%d: %s*\n", topic, bestTopic)
-		for i := range bestDocs[topic] {
-			response += fmt.Sprintf("<%s>%s\n", bestDocs[topic][i].user, bestDocs[topic][i].body)
-		}
-
-	}
-
-	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, response)
-
-	return true
-}
-
 func (p *TLDRPlugin) addHistory(ch string, hist history) {
 	p.history[ch] = append(p.history[ch], hist)
 	sz := len(p.history[ch])
@@ -209,34 +119,12 @@ func (p *TLDRPlugin) addHistory(ch string, hist history) {
 	p.history[ch] = p.history[ch][begin:]
 }
 
-func (p *TLDRPlugin) getTopics() []string {
-	hist := []string{}
-	for _, ch := range p.history {
-		for _, h := range ch {
-			hist = append(hist, h.body)
-		}
-	}
-	return hist
-}
-
 // Help responds to help requests. Every plugin must implement a help function.
 func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool {
 	p.b.Send(c, bot.Message, message.Channel, "tl;dr")
 	return true
 }
 
-func min(slice []float64) (float64, int) {
-	minVal := 1.
-	minIndex := -1
-	for index, val := range slice {
-		if val < minVal {
-			minVal = val
-			minIndex = index
-		}
-	}
-	return minVal, minIndex
-}
-
 func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
 	ch := r.Msg.Channel
 	c, err := p.getClient()
@@ -259,17 +147,34 @@ func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
 		}
 		backlog = str + backlog
 	}
-	sess := c.NewChatSession(prompt.String())
-	completion, err := sess.Complete(context.TODO(), backlog)
+
+	model := c.GenerativeModel("gemini-1.5-flash")
+	model.SystemInstruction = &genai.Content{
+		Parts: []genai.Part{genai.Text(prompt.String())},
+	}
+	res, err := model.GenerateContent(context.Background(), genai.Text(backlog))
 	if err != nil {
-		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't run the OpenAI request")
+		log.Error().Err(err).Send()
+		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: could not generate a TLDR")
 		return true
 	}
 	log.Debug().
 		Str("prompt", prompt.String()).
 		Str("backlog", backlog).
-		Str("completion", completion).
+		Interface("completion", res.Candidates).
 		Msgf("tl;dr")
+
+	if len(res.Candidates) == 0 {
+		log.Error().Err(errors.New("no candidates found")).Send()
+		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: no candidates generating a TLDR")
+		return true
+	}
+
+	completion := ""
+	for _, p := range res.Candidates[0].Content.Parts {
+		completion += fmt.Sprintf("%s", p)
+	}
+
 	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, completion)
 	return true
 }
@@ -295,10 +200,15 @@ func (p *TLDRPlugin) setTLDR(r bot.Request) bool {
 	return true
 }
 
-func (p *TLDRPlugin) getClient() (*openai.Client, error) {
-	token := p.c.Get("gpt.token", "")
-	if token == "" {
-		return nil, fmt.Errorf("no GPT token given")
+func (p *TLDRPlugin) getClient() (*genai.Client, error) {
+	ctx := context.Background()
+	key := p.c.Get("GEMINI_API_KEY", "")
+	if key == "" {
+		return nil, errors.New("missing GEMINI_API_KEY")
 	}
-	return openai.NewClient(token)
+	client, err := genai.NewClient(ctx, option.WithAPIKey(key))
+	if err != nil {
+		return nil, err
+	}
+	return client, nil
 }