From c6cb8053c0b4346828e0f7521e8ba60f1c6ab5ec Mon Sep 17 00:00:00 2001 From: Chris Sexton Date: Thu, 20 Aug 2020 17:20:35 -0400 Subject: [PATCH] gpt2: add plugin --- main.go | 3 ++ plugins/gpt2/gpt2.go | 68 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 plugins/gpt2/gpt2.go diff --git a/main.go b/main.go index 38570ed..c94f592 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,8 @@ import ( "os" "time" + "github.com/velour/catbase/plugins/gpt2" + "github.com/velour/catbase/plugins/achievements" "github.com/velour/catbase/plugins/aoc" "github.com/velour/catbase/plugins/countdown" @@ -148,6 +150,7 @@ func main() { b.AddPlugin(achievements.New(b)) b.AddPlugin(sms.New(b)) b.AddPlugin(countdown.New(b)) + b.AddPlugin(gpt2.New(b)) // catches anything left, will always return true b.AddPlugin(fact.New(b)) diff --git a/plugins/gpt2/gpt2.go b/plugins/gpt2/gpt2.go new file mode 100644 index 0000000..da03f69 --- /dev/null +++ b/plugins/gpt2/gpt2.go @@ -0,0 +1,68 @@ +package gpt2 + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "github.com/velour/catbase/bot" + "github.com/velour/catbase/bot/msg" + "github.com/velour/catbase/config" +) + +type GPT2Plugin struct { + b bot.Bot + c *config.Config +} + +func New(b bot.Bot) *GPT2Plugin { + p := &GPT2Plugin{ + b: b, + c: b.Config(), + } + + b.Register(p, bot.Message, p.message) + b.Register(p, bot.Help, p.help) + + return p +} + +func (p *GPT2Plugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool { + ch := message.Channel + lowerBody := strings.ToLower(message.Body) + if message.Command && strings.HasPrefix(lowerBody, "gpt2") { + input := strings.TrimPrefix(lowerBody, "gpt2") + p.b.Send(c, bot.Message, ch, p.getGPTText(input)) + return true + } + return false +} + +func (p *GPT2Plugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool { + txt := "Invoke the GPT-2 API with: `!gpt2 " + p.b.Send(c, bot.Message, message.Channel, txt) + return true +} + +func (p *GPT2Plugin) getGPTText(input string) string { + serviceURI := p.c.Get("gpt.service", "") + if serviceURI == "" { + return "Cannot contact GPT2 service." + } + values := url.Values{} + values.Add("text", input) + resp, err := http.PostForm(serviceURI, values) + if err != nil { + return fmt.Sprintf("Error retrieving GPT2 response: %s", err) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Sprintf("Error reading GPT2 response: %s", err) + } + resp.Body.Close() + txt := string(body) + txt = strings.TrimSpace(txt) + return txt +}