mirror of https://github.com/velour/catbase.git
tldr: use Gemini
This commit is contained in:
parent
24290b27d1
commit
40b9507d3e
|
@ -3,9 +3,11 @@ package tldr
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/andrewstuart/openai"
|
"github.com/google/generative-ai-go/genai"
|
||||||
"github.com/velour/catbase/config"
|
"github.com/velour/catbase/config"
|
||||||
|
"google.golang.org/api/option"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
@ -15,8 +17,6 @@ import (
|
||||||
"github.com/velour/catbase/bot/msg"
|
"github.com/velour/catbase/bot/msg"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
"github.com/james-bowman/nlp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const templateKey = "tldr.prompttemplate"
|
const templateKey = "tldr.prompttemplate"
|
||||||
|
@ -51,12 +51,6 @@ func New(b bot.Bot) *TLDRPlugin {
|
||||||
|
|
||||||
func (p *TLDRPlugin) register() {
|
func (p *TLDRPlugin) register() {
|
||||||
p.b.RegisterTable(p, bot.HandlerTable{
|
p.b.RegisterTable(p, bot.HandlerTable{
|
||||||
{
|
|
||||||
Kind: bot.Message, IsCmd: true,
|
|
||||||
Regex: regexp.MustCompile(`old tl;dr`),
|
|
||||||
HelpText: "Get a rather inaccurate summary of the channel",
|
|
||||||
Handler: p.tldrCmd,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Kind: bot.Message, IsCmd: true,
|
Kind: bot.Message, IsCmd: true,
|
||||||
Regex: regexp.MustCompile(`tl;?dr-prompt$`),
|
Regex: regexp.MustCompile(`tl;?dr-prompt$`),
|
||||||
|
@ -90,15 +84,6 @@ func (p *TLDRPlugin) register() {
|
||||||
p.b.Register(p, bot.Help, p.help)
|
p.b.Register(p, bot.Help, p.help)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLDRPlugin) tldrCmd(r bot.Request) bool {
|
|
||||||
timeLimit := time.Duration(p.b.Config().GetInt("TLDR.HourLimit", 1))
|
|
||||||
if p.lastRequest.After(time.Now().Add(-timeLimit * time.Hour)) {
|
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Slow down, cowboy. Read that tiny backlog.")
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TLDRPlugin) record(r bot.Request) bool {
|
func (p *TLDRPlugin) record(r bot.Request) bool {
|
||||||
hist := history{
|
hist := history{
|
||||||
body: strings.ToLower(r.Msg.Body),
|
body: strings.ToLower(r.Msg.Body),
|
||||||
|
@ -110,81 +95,6 @@ func (p *TLDRPlugin) record(r bot.Request) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLDRPlugin) oldTLDR(r bot.Request) bool {
|
|
||||||
p.lastRequest = time.Now()
|
|
||||||
nTopics := p.b.Config().GetInt("TLDR.Topics", 5)
|
|
||||||
|
|
||||||
stopWordSlice := p.b.Config().GetArray("TLDR.StopWords", []string{})
|
|
||||||
if len(stopWordSlice) == 0 {
|
|
||||||
stopWordSlice = THESE_ARE_NOT_THE_WORDS_YOU_ARE_LOOKING_FOR
|
|
||||||
p.b.Config().SetArray("TLDR.StopWords", stopWordSlice)
|
|
||||||
}
|
|
||||||
|
|
||||||
vectoriser := nlp.NewCountVectoriser(stopWordSlice...)
|
|
||||||
lda := nlp.NewLatentDirichletAllocation(nTopics)
|
|
||||||
pipeline := nlp.NewPipeline(vectoriser, lda)
|
|
||||||
docsOverTopics, err := pipeline.FitTransform(p.getTopics()...)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
bestScores := make([][]float64, nTopics)
|
|
||||||
bestDocs := make([][]history, nTopics)
|
|
||||||
|
|
||||||
supportingDocs := p.b.Config().GetInt("TLDR.Support", 3)
|
|
||||||
for i := 0; i < nTopics; i++ {
|
|
||||||
bestScores[i] = make([]float64, supportingDocs)
|
|
||||||
bestDocs[i] = make([]history, supportingDocs)
|
|
||||||
}
|
|
||||||
|
|
||||||
dr, dc := docsOverTopics.Dims()
|
|
||||||
for topic := 0; topic < dr; topic++ {
|
|
||||||
minScore, minIndex := min(bestScores[topic])
|
|
||||||
|
|
||||||
for doc := 0; doc < dc; doc++ {
|
|
||||||
score := docsOverTopics.At(topic, doc)
|
|
||||||
if score > minScore {
|
|
||||||
bestScores[topic][minIndex] = score
|
|
||||||
bestDocs[topic][minIndex] = p.history[r.Msg.Channel][doc]
|
|
||||||
minScore, minIndex = min(bestScores[topic])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
topicsOverWords := lda.Components()
|
|
||||||
tr, tc := topicsOverWords.Dims()
|
|
||||||
|
|
||||||
vocab := make([]string, len(vectoriser.Vocabulary))
|
|
||||||
for k, v := range vectoriser.Vocabulary {
|
|
||||||
vocab[v] = k
|
|
||||||
}
|
|
||||||
|
|
||||||
response := "Here you go captain 'too good to read backlog':\n"
|
|
||||||
|
|
||||||
for topic := 0; topic < tr; topic++ {
|
|
||||||
bestScore := -1.
|
|
||||||
bestTopic := ""
|
|
||||||
for word := 0; word < tc; word++ {
|
|
||||||
score := topicsOverWords.At(topic, word)
|
|
||||||
if score > bestScore {
|
|
||||||
bestScore = score
|
|
||||||
bestTopic = vocab[word]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
response += fmt.Sprintf("\n*Topic #%d: %s*\n", topic, bestTopic)
|
|
||||||
for i := range bestDocs[topic] {
|
|
||||||
response += fmt.Sprintf("<%s>%s\n", bestDocs[topic][i].user, bestDocs[topic][i].body)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, response)
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TLDRPlugin) addHistory(ch string, hist history) {
|
func (p *TLDRPlugin) addHistory(ch string, hist history) {
|
||||||
p.history[ch] = append(p.history[ch], hist)
|
p.history[ch] = append(p.history[ch], hist)
|
||||||
sz := len(p.history[ch])
|
sz := len(p.history[ch])
|
||||||
|
@ -209,34 +119,12 @@ func (p *TLDRPlugin) addHistory(ch string, hist history) {
|
||||||
p.history[ch] = p.history[ch][begin:]
|
p.history[ch] = p.history[ch][begin:]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLDRPlugin) getTopics() []string {
|
|
||||||
hist := []string{}
|
|
||||||
for _, ch := range p.history {
|
|
||||||
for _, h := range ch {
|
|
||||||
hist = append(hist, h.body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return hist
|
|
||||||
}
|
|
||||||
|
|
||||||
// Help responds to help requests. Every plugin must implement a help function.
|
// Help responds to help requests. Every plugin must implement a help function.
|
||||||
func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool {
|
func (p *TLDRPlugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...any) bool {
|
||||||
p.b.Send(c, bot.Message, message.Channel, "tl;dr")
|
p.b.Send(c, bot.Message, message.Channel, "tl;dr")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func min(slice []float64) (float64, int) {
|
|
||||||
minVal := 1.
|
|
||||||
minIndex := -1
|
|
||||||
for index, val := range slice {
|
|
||||||
if val < minVal {
|
|
||||||
minVal = val
|
|
||||||
minIndex = index
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return minVal, minIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
|
func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
|
||||||
ch := r.Msg.Channel
|
ch := r.Msg.Channel
|
||||||
c, err := p.getClient()
|
c, err := p.getClient()
|
||||||
|
@ -259,17 +147,34 @@ func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
|
||||||
}
|
}
|
||||||
backlog = str + backlog
|
backlog = str + backlog
|
||||||
}
|
}
|
||||||
sess := c.NewChatSession(prompt.String())
|
|
||||||
completion, err := sess.Complete(context.TODO(), backlog)
|
model := c.GenerativeModel("gemini-1.5-flash")
|
||||||
|
model.SystemInstruction = &genai.Content{
|
||||||
|
Parts: []genai.Part{genai.Text(prompt.String())},
|
||||||
|
}
|
||||||
|
res, err := model.GenerateContent(context.Background(), genai.Text(backlog))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't run the OpenAI request")
|
log.Error().Err(err).Send()
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: could not generate a TLDR")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("prompt", prompt.String()).
|
Str("prompt", prompt.String()).
|
||||||
Str("backlog", backlog).
|
Str("backlog", backlog).
|
||||||
Str("completion", completion).
|
Interface("completion", res.Candidates).
|
||||||
Msgf("tl;dr")
|
Msgf("tl;dr")
|
||||||
|
|
||||||
|
if len(res.Candidates) == 0 {
|
||||||
|
log.Error().Err(errors.New("no candidates found")).Send()
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: no candidates generating a TLDR")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
completion := ""
|
||||||
|
for _, p := range res.Candidates[0].Content.Parts {
|
||||||
|
completion += fmt.Sprintf("%s", p)
|
||||||
|
}
|
||||||
|
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, completion)
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, completion)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -295,10 +200,15 @@ func (p *TLDRPlugin) setTLDR(r bot.Request) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLDRPlugin) getClient() (*openai.Client, error) {
|
func (p *TLDRPlugin) getClient() (*genai.Client, error) {
|
||||||
token := p.c.Get("gpt.token", "")
|
ctx := context.Background()
|
||||||
if token == "" {
|
key := p.c.Get("GEMINI_API_KEY", "")
|
||||||
return nil, fmt.Errorf("no GPT token given")
|
if key == "" {
|
||||||
|
return nil, errors.New("missing GEMINI_API_KEY")
|
||||||
}
|
}
|
||||||
return openai.NewClient(token)
|
client, err := genai.NewClient(ctx, option.WithAPIKey(key))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue