catbase/plugins/gpt2/gpt2.go

82 lines
1.8 KiB
Go
Raw Normal View History

2020-08-20 21:20:35 +00:00
package gpt2
import (
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/velour/catbase/bot"
	"github.com/velour/catbase/bot/msg"
	"github.com/velour/catbase/config"
)
// GPT2Plugin is a chat plugin that forwards seed text to an HTTP GPT-2
// service and posts the generated continuation back to the channel.
type GPT2Plugin struct {
	b bot.Bot        // bot used to register handlers and send replies
	c *config.Config // configuration source; supplies the "gpt.service" URI
}
// New builds a GPT2Plugin bound to b (taking its configuration from
// b.Config()) and registers the plugin's message and help handlers with b.
func New(b bot.Bot) *GPT2Plugin {
	plugin := &GPT2Plugin{b: b, c: b.Config()}

	b.Register(plugin, bot.Message, plugin.message)
	b.Register(plugin, bot.Help, plugin.help)

	return plugin
}
2020-08-21 14:19:36 +00:00
const (
	// prefix is the command word that triggers this plugin; it is matched
	// case-insensitively against the start of a command message.
	prefix = "gpt2"
)
2020-08-20 21:20:35 +00:00
func (p *GPT2Plugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool {
ch := message.Channel
lowerBody := strings.ToLower(message.Body)
2020-08-21 14:19:36 +00:00
if message.Command && strings.HasPrefix(lowerBody, prefix) {
input := message.Body[len(prefix)+1:]
2020-08-20 21:20:35 +00:00
p.b.Send(c, bot.Message, ch, p.getGPTText(input))
return true
}
return false
}
// help replies in the requesting channel with usage instructions for the
// gpt2 command. It always returns true.
func (p *GPT2Plugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool {
	// The command example is wrapped in a matched pair of backticks so chat
	// clients render it as inline code.
	txt := "Invoke the GPT-2 API with: `!gpt2 <your seed text>`"
	p.b.Send(c, bot.Message, message.Channel, txt)
	return true
}
2020-08-21 14:19:36 +00:00
const (
	// separator is the end-of-text marker in generated output; everything
	// from its first occurrence onward is discarded before posting.
	separator = "<|endoftext|>"
)
2020-08-20 21:20:35 +00:00
func (p *GPT2Plugin) getGPTText(input string) string {
serviceURI := p.c.Get("gpt.service", "")
if serviceURI == "" {
return "Cannot contact GPT2 service."
}
values := url.Values{}
values.Add("text", input)
resp, err := http.PostForm(serviceURI, values)
if err != nil {
return fmt.Sprintf("Error retrieving GPT2 response: %s", err)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Sprintf("Error reading GPT2 response: %s", err)
}
resp.Body.Close()
2020-08-21 14:19:36 +00:00
txt := p.cleanup(input + string(body))
return txt
}
func (p *GPT2Plugin) cleanup(txt string) string {
txt = strings.Split(txt, separator)[0]
if !strings.HasSuffix(txt, ".") && strings.Count(txt, ".") > 1 {
idx := strings.LastIndexAny(txt, ".")
txt = txt[:idx+1]
}
2020-08-20 21:20:35 +00:00
txt = strings.TrimSpace(txt)
return txt
}