mirror of https://github.com/velour/catbase.git
gpt2: add plugin
This commit is contained in:
parent
d69d708245
commit
c6cb8053c0
3
main.go
3
main.go
|
@ -9,6 +9,8 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/velour/catbase/plugins/gpt2"
|
||||||
|
|
||||||
"github.com/velour/catbase/plugins/achievements"
|
"github.com/velour/catbase/plugins/achievements"
|
||||||
"github.com/velour/catbase/plugins/aoc"
|
"github.com/velour/catbase/plugins/aoc"
|
||||||
"github.com/velour/catbase/plugins/countdown"
|
"github.com/velour/catbase/plugins/countdown"
|
||||||
|
@ -148,6 +150,7 @@ func main() {
|
||||||
b.AddPlugin(achievements.New(b))
|
b.AddPlugin(achievements.New(b))
|
||||||
b.AddPlugin(sms.New(b))
|
b.AddPlugin(sms.New(b))
|
||||||
b.AddPlugin(countdown.New(b))
|
b.AddPlugin(countdown.New(b))
|
||||||
|
b.AddPlugin(gpt2.New(b))
|
||||||
// catches anything left, will always return true
|
// catches anything left, will always return true
|
||||||
b.AddPlugin(fact.New(b))
|
b.AddPlugin(fact.New(b))
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
package gpt2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/velour/catbase/bot"
|
||||||
|
"github.com/velour/catbase/bot/msg"
|
||||||
|
"github.com/velour/catbase/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GPT2Plugin struct {
|
||||||
|
b bot.Bot
|
||||||
|
c *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(b bot.Bot) *GPT2Plugin {
|
||||||
|
p := &GPT2Plugin{
|
||||||
|
b: b,
|
||||||
|
c: b.Config(),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Register(p, bot.Message, p.message)
|
||||||
|
b.Register(p, bot.Help, p.help)
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GPT2Plugin) message(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool {
|
||||||
|
ch := message.Channel
|
||||||
|
lowerBody := strings.ToLower(message.Body)
|
||||||
|
if message.Command && strings.HasPrefix(lowerBody, "gpt2") {
|
||||||
|
input := strings.TrimPrefix(lowerBody, "gpt2")
|
||||||
|
p.b.Send(c, bot.Message, ch, p.getGPTText(input))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GPT2Plugin) help(c bot.Connector, kind bot.Kind, message msg.Message, args ...interface{}) bool {
|
||||||
|
txt := "Invoke the GPT-2 API with: `!gpt2 <your seed text>"
|
||||||
|
p.b.Send(c, bot.Message, message.Channel, txt)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GPT2Plugin) getGPTText(input string) string {
|
||||||
|
serviceURI := p.c.Get("gpt.service", "")
|
||||||
|
if serviceURI == "" {
|
||||||
|
return "Cannot contact GPT2 service."
|
||||||
|
}
|
||||||
|
values := url.Values{}
|
||||||
|
values.Add("text", input)
|
||||||
|
resp, err := http.PostForm(serviceURI, values)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("Error retrieving GPT2 response: %s", err)
|
||||||
|
}
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("Error reading GPT2 response: %s", err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
txt := string(body)
|
||||||
|
txt = strings.TrimSpace(txt)
|
||||||
|
return txt
|
||||||
|
}
|
Loading…
Reference in New Issue