From c20df2d6598229f3977fa7d3847feed950fd40ce Mon Sep 17 00:00:00 2001 From: Chris Sexton <3216719+chrissexton@users.noreply.github.com> Date: Sat, 28 Sep 2024 10:52:49 -0400 Subject: [PATCH] talklikeapirate: add controls --- config/config.go | 4 + connectors/discord/discord.go | 2 +- main.go | 2 +- plugins/plugins.go | 2 + .../talklikeapirate/{pirate.go => filter.go} | 14 ++-- plugins/talklikeapirate/plugin.go | 84 +++++++++++++++++++ 6 files changed, 99 insertions(+), 9 deletions(-) rename plugins/talklikeapirate/{pirate.go => filter.go} (86%) create mode 100644 plugins/talklikeapirate/plugin.go diff --git a/config/config.go b/config/config.go index 40b4d0a..54941d1 100644 --- a/config/config.go +++ b/config/config.go @@ -179,6 +179,10 @@ func (c *Config) Set(key, value string) error { return nil } +func (c *Config) SetBool(key string, value bool) error { + return c.Set(key, fmt.Sprintf("%v", value)) +} + func (c *Config) RefreshSecrets() error { q := `select key, value from secrets` var secrets []Secret diff --git a/connectors/discord/discord.go b/connectors/discord/discord.go index 4de2d62..2c75175 100644 --- a/connectors/discord/discord.go +++ b/connectors/discord/discord.go @@ -35,7 +35,7 @@ type Discord struct { guildID string - Pirate *talklikeapirate.TalkLikeAPiratePlugin + Pirate *talklikeapirate.TalkLikeAPirateFilter } func New(config *config.Config) *Discord { diff --git a/main.go b/main.go index 3f6dac6..2409316 100644 --- a/main.go +++ b/main.go @@ -73,7 +73,7 @@ func main() { client = slackapp.New(c) case "discord": d := discord.New(c) - d.Pirate = talklikeapirate.New(c) + d.Pirate = talklikeapirate.NewFilter(c) client = d default: log.Fatal().Msgf("Unknown connection type: %s", c.Get("type", "UNSET")) diff --git a/plugins/plugins.go b/plugins/plugins.go index ae9f9bb..8f46560 100644 --- a/plugins/plugins.go +++ b/plugins/plugins.go @@ -45,6 +45,7 @@ import ( "github.com/velour/catbase/plugins/sms" "github.com/velour/catbase/plugins/stock" "github.com/velour/catbase/plugins/talker" + "github.com/velour/catbase/plugins/talklikeapirate" "github.com/velour/catbase/plugins/tappd" "github.com/velour/catbase/plugins/tell" "github.com/velour/catbase/plugins/tldr" @@ -102,6 +103,7 @@ func Register(b bot.Bot) { b.AddPlugin(talker.New(b)) b.AddPlugin(fact.New(b)) b.AddPlugin(llm.New(b)) + b.AddPlugin(talklikeapirate.New(b)) // catches anything left, will always return true b.AddPlugin(deadend.New(b)) } diff --git a/plugins/talklikeapirate/pirate.go b/plugins/talklikeapirate/filter.go similarity index 86% rename from plugins/talklikeapirate/pirate.go rename to plugins/talklikeapirate/filter.go index 3a8ef31..f7d3d17 100644 --- a/plugins/talklikeapirate/pirate.go +++ b/plugins/talklikeapirate/filter.go @@ -11,9 +11,9 @@ import ( "google.golang.org/api/option" ) -// TalkLikeAPiratePlugin reimplements the send function +// TalkLikeAPirateFilter reimplements the send function // with an AI intermediate. -type TalkLikeAPiratePlugin struct { +type TalkLikeAPirateFilter struct { client *genai.Client prompt string @@ -21,15 +21,15 @@ type TalkLikeAPiratePlugin struct { c *config.Config } -func New(c *config.Config) *TalkLikeAPiratePlugin { - p := &TalkLikeAPiratePlugin{ +func NewFilter(c *config.Config) *TalkLikeAPirateFilter { + p := &TalkLikeAPirateFilter{ c: c, } return p } -func (p *TalkLikeAPiratePlugin) Filter(input string) (string, error) { +func (p *TalkLikeAPirateFilter) Filter(input string) (string, error) { if !p.c.GetBool("talklikeapirate.enabled", false) { return input, nil } @@ -69,7 +69,7 @@ func (p *TalkLikeAPiratePlugin) Filter(input string) (string, error) { return completion, nil } -func (p *TalkLikeAPiratePlugin) GetModel() (*genai.GenerativeModel, error) { +func (p *TalkLikeAPirateFilter) GetModel() (*genai.GenerativeModel, error) { model := p.client.GenerativeModel(p.c.Get("gemini.model", "gemini-1.5-flash")) model.SetMaxOutputTokens(int32(p.c.GetInt("gemini.maxtokens", 100))) model.SetTopP(float32(p.c.GetFloat64("gemini.topp", 0.95))) @@ -94,7 +94,7 @@ func (p *TalkLikeAPiratePlugin) GetModel() (*genai.GenerativeModel, error) { return model, nil } -func (p *TalkLikeAPiratePlugin) getClient() (*genai.Client, error) { +func (p *TalkLikeAPirateFilter) getClient() (*genai.Client, error) { ctx := context.Background() key := p.c.Get("GEMINI_API_KEY", "") if key == "" { diff --git a/plugins/talklikeapirate/plugin.go b/plugins/talklikeapirate/plugin.go new file mode 100644 index 0000000..5dcd352 --- /dev/null +++ b/plugins/talklikeapirate/plugin.go @@ -0,0 +1,84 @@ +package talklikeapirate + +import ( + "fmt" + "github.com/velour/catbase/bot" + "github.com/velour/catbase/config" + "regexp" + "strings" +) + +// TalkLikeAPiratePlugin allows admin of the filter +type TalkLikeAPiratePlugin struct { + b bot.Bot + c *config.Config + handlers bot.HandlerTable +} + +func New(b bot.Bot) *TalkLikeAPiratePlugin { + p := &TalkLikeAPiratePlugin{ + b: b, + c: b.Config(), + } + + p.register() + + return p +} + +func (p *TalkLikeAPiratePlugin) register() { + p.handlers = bot.HandlerTable{ + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`^enable pirate$`), + HelpText: "Enable message filter", + Handler: p.setEnabled(true), + }, + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`^disable pirate$`), + HelpText: "Disable message filter", + Handler: p.setEnabled(false), + }, + { + Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile(`^pirate-prompt:? (?P.*)$`), + HelpText: "Set message filter prompt", + Handler: p.setPrompt, + }, + { + Kind: bot.Help, IsCmd: false, + Regex: regexp.MustCompile(`.*`), + Handler: p.help, + }, + } + p.b.RegisterTable(p, p.handlers) +} + +func (p *TalkLikeAPiratePlugin) setEnabled(isEnabled bool) bot.ResponseHandler { + return func(r bot.Request) bool { + p.c.SetBool("talklikeapirate.enabled", isEnabled) + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("I just set the message filter status to: %v", isEnabled)) + return true + } +} + +func (p *TalkLikeAPiratePlugin) setPrompt(r bot.Request) bool { + prompt := r.Values["text"] + p.c.Set("talklikeapirate.systemprompt", prompt) + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("I set the message filter prompt to: %s", prompt)) + return true +} + +func (p *TalkLikeAPiratePlugin) help(r bot.Request) bool { + out := "Talk like a pirate commands:\n" + for _, h := range p.handlers { + if h.HelpText == "" { + continue + } + out += fmt.Sprintf("```%s```\t%s", h.Regex.String(), h.HelpText) + } + out = strings.TrimSpace(out) + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, out) + return true +}