From 624258a794aa3c19403cc07811a668c82da1bcda Mon Sep 17 00:00:00 2001
From: Chris Sexton <3216719+chrissexton@users.noreply.github.com>
Date: Sat, 28 Sep 2024 10:28:58 -0400
Subject: [PATCH] talklikeapirate: we just went nucular

---
 connectors/discord/discord.go     |  12 +++-
 main.go                           |   5 +-
 plugins/llm/gemini.go             |   2 +-
 plugins/talklikeapirate/pirate.go | 108 ++++++++++++++++++++++++++++++
 plugins/tldr/tldr.go              |   4 +-
 5 files changed, 126 insertions(+), 5 deletions(-)
 create mode 100644 plugins/talklikeapirate/pirate.go

diff --git a/connectors/discord/discord.go b/connectors/discord/discord.go
index fac9878..4de2d62 100644
--- a/connectors/discord/discord.go
+++ b/connectors/discord/discord.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"errors"
 	"fmt"
+	"github.com/velour/catbase/plugins/talklikeapirate"
 	"net/http"
 	"strconv"
 	"strings"
@@ -33,6 +34,8 @@ type Discord struct {
 	cmdHandlers map[string]CmdHandler
 
 	guildID string
+
+	Pirate *talklikeapirate.TalkLikeAPiratePlugin
 }
 
 func New(config *config.Config) *Discord {
@@ -112,6 +115,14 @@ func (d Discord) Send(kind bot.Kind, args ...any) (string, error) {
 }
 
 func (d *Discord) sendMessage(channel, message string, meMessage bool, args ...any) (string, error) {
+	var err error
+	if d.Pirate != nil {
+		message, err = d.Pirate.Filter(message)
+		if err != nil {
+			log.Error().Err(err).Msg("could not pirate message")
+		}
+	}
+
 	if meMessage && !strings.HasPrefix(message, "_") && !strings.HasSuffix(message, "_") {
 		message = "_" + message + "_"
 	}
@@ -167,7 +178,6 @@ func (d *Discord) sendMessage(channel, message string, meMessage bool, args ...a
 	maxLen := 2000
 	chunkSize := maxLen - 100
 	var st *discordgo.Message
-	var err error
 	if len(data.Content) > maxLen {
 		tmp := data.Content
 		data.Content = tmp[:chunkSize]
diff --git a/main.go b/main.go
index 545846f..3f6dac6 100644
--- a/main.go
+++ b/main.go
@@ -7,6 +7,7 @@ package main
 import (
 	"flag"
 	"github.com/velour/catbase/plugins"
+	"github.com/velour/catbase/plugins/talklikeapirate"
 	"io"
 	"math/rand"
 	"os"
@@ -71,7 +72,9 @@ func main() {
 	case "slackapp":
 		client = slackapp.New(c)
 	case "discord":
-		client = discord.New(c)
+		d := discord.New(c)
+		d.Pirate = talklikeapirate.New(c)
+		client = d
 	default:
 		log.Fatal().Msgf("Unknown connection type: %s", c.Get("type", "UNSET"))
 	}
diff --git a/plugins/llm/gemini.go b/plugins/llm/gemini.go
index 1974ae2..578aa32 100644
--- a/plugins/llm/gemini.go
+++ b/plugins/llm/gemini.go
@@ -25,7 +25,7 @@ func (p *LLMPlugin) geminiConnect() error {
 }
 
 func (p *LLMPlugin) gemini(msg string) (chatEntry, error) {
-	model := p.geminiClient.GenerativeModel("gemini-1.5-flash")
+	model := p.geminiClient.GenerativeModel(p.c.Get("gemini.model", "gemini-1.5-flash"))
 	model.SetMaxOutputTokens(int32(p.c.GetInt("gemini.maxtokens", 100)))
 	model.SetTopP(float32(p.c.GetFloat64("gemini.topp", 0.95)))
 	model.SetTopK(int32(p.c.GetInt("gemini.topk", 20)))
diff --git a/plugins/talklikeapirate/pirate.go b/plugins/talklikeapirate/pirate.go
new file mode 100644
index 0000000..3a8ef31
--- /dev/null
+++ b/plugins/talklikeapirate/pirate.go
@@ -0,0 +1,108 @@
+package talklikeapirate
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"github.com/google/generative-ai-go/genai"
+	"github.com/rs/zerolog/log"
+	"github.com/velour/catbase/bot"
+	"github.com/velour/catbase/config"
+	"google.golang.org/api/option"
+)
+
+// TalkLikeAPiratePlugin reimplements the send function
+// with an AI intermediate.
+type TalkLikeAPiratePlugin struct {
+	client *genai.Client
+	prompt string
+
+	b bot.Bot
+	c *config.Config
+}
+
+func New(c *config.Config) *TalkLikeAPiratePlugin {
+	p := &TalkLikeAPiratePlugin{
+		c: c,
+	}
+
+	return p
+}
+
+func (p *TalkLikeAPiratePlugin) Filter(input string) (string, error) {
+	if !p.c.GetBool("talklikeapirate.enabled", false) {
+		return input, nil
+	}
+	if p.client == nil {
+		var err error
+		p.client, err = p.getClient()
+		if err != nil {
+			return input, err
+		}
+	}
+	model, err := p.GetModel()
+	if err != nil {
+		log.Error().Err(err).Send()
+		return input, err
+	}
+
+	res, err := model.GenerateContent(context.Background(), genai.Text(input))
+	if err != nil {
+		log.Error().Err(err).Send()
+		return input, err
+	}
+
+	if len(res.Candidates) == 0 {
+		err := errors.New("no candidates found")
+		log.Error().Err(err).Send()
+		return input, err
+	}
+
+	// Need to check here that we got an actual completion, not a
+	// warning about bad content. FinishReason exists on Completion.
+
+	completion := ""
+	for _, p := range res.Candidates[0].Content.Parts {
+		completion += fmt.Sprintf("%s", p)
+	}
+
+	return completion, nil
+}
+
+func (p *TalkLikeAPiratePlugin) GetModel() (*genai.GenerativeModel, error) {
+	model := p.client.GenerativeModel(p.c.Get("gemini.model", "gemini-1.5-flash"))
+	model.SetMaxOutputTokens(int32(p.c.GetInt("gemini.maxtokens", 100)))
+	model.SetTopP(float32(p.c.GetFloat64("gemini.topp", 0.95)))
+	model.SetTopK(int32(p.c.GetInt("gemini.topk", 20)))
+	model.SetTemperature(float32(p.c.GetFloat64("gemini.temp", 0.9)))
+
+	model.SafetySettings = []*genai.SafetySetting{
+		{genai.HarmCategoryHarassment, genai.HarmBlockNone},
+		{genai.HarmCategoryHateSpeech, genai.HarmBlockNone},
+		{genai.HarmCategorySexuallyExplicit, genai.HarmBlockNone},
+		{genai.HarmCategoryDangerousContent, genai.HarmBlockNone},
+	}
+
+	if prompt := p.c.Get("talklikeapirate.systemprompt", ""); prompt != "" {
+		model.SystemInstruction = &genai.Content{
+			Parts: []genai.Part{genai.Text(prompt)},
+		}
+	} else {
+		return nil, errors.New("no system prompt selected")
+	}
+
+	return model, nil
+}
+
+func (p *TalkLikeAPiratePlugin) getClient() (*genai.Client, error) {
+	ctx := context.Background()
+	key := p.c.Get("GEMINI_API_KEY", "")
+	if key == "" {
+		return nil, errors.New("missing GEMINI_API_KEY")
+	}
+	client, err := genai.NewClient(ctx, option.WithAPIKey(key))
+	if err != nil {
+		return nil, err
+	}
+	return client, nil
+}
diff --git a/plugins/tldr/tldr.go b/plugins/tldr/tldr.go
index 9d120cb..f9c6a29 100644
--- a/plugins/tldr/tldr.go
+++ b/plugins/tldr/tldr.go
@@ -129,7 +129,7 @@ func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
 	ch := r.Msg.Channel
 	c, err := p.getClient()
 	if err != nil {
-		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an OpenAI client")
+		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Couldn't fetch an AI client")
 		return true
 	}
 	promptConfig := p.c.Get(templateKey, defaultTemplate)
@@ -148,7 +148,7 @@ func (p *TLDRPlugin) betterTLDR(r bot.Request) bool {
 		backlog = str + backlog
 	}
 
-	model := c.GenerativeModel("gemini-1.5-flash")
+	model := c.GenerativeModel(p.c.Get("gemini.model", "gemini-1.5-flash"))
 	model.SystemInstruction = &genai.Content{
 		Parts: []genai.Part{genai.Text(prompt.String())},
 	}
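
Usage sketch (illustrative only, not part of the patch): this shows how another connector could reuse the new filter through the API added in pirate.go. The helper pirateify and its package name are hypothetical; it assumes talklikeapirate.enabled, talklikeapirate.systemprompt, and GEMINI_API_KEY are set in the bot's config.

package example // hypothetical package, for illustration only

import (
	"github.com/rs/zerolog/log"

	"github.com/velour/catbase/config"
	"github.com/velour/catbase/plugins/talklikeapirate"
)

// pirateify (hypothetical helper) runs an outgoing message through the
// pirate filter. Filter returns the original input on every error path and
// is a no-op when talklikeapirate.enabled is false, so callers can always
// send its result.
func pirateify(cfg *config.Config, msg string) string {
	p := talklikeapirate.New(cfg)
	out, err := p.Filter(msg)
	if err != nil {
		log.Error().Err(err).Msg("pirate filter failed")
		return msg
	}
	return out
}

The Discord connector in the patch follows the same pattern: on error it logs and falls through to send the unfiltered message, since Filter hands back its input alongside any error.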