mirror of https://github.com/velour/catbase.git
Compare commits
No commits in common. "d2c5489634eb3ec08b7dad9b1e92cd2858d71e2a" and "338293cf427f8ae811edfeac3f0a14ed72c0ceb6" have entirely different histories.
d2c5489634
...
338293cf42
2
main.go
2
main.go
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/velour/catbase/bot/msg"
|
"github.com/velour/catbase/bot/msg"
|
||||||
"github.com/velour/catbase/connectors/discord"
|
"github.com/velour/catbase/connectors/discord"
|
||||||
"github.com/velour/catbase/plugins/giphy"
|
"github.com/velour/catbase/plugins/giphy"
|
||||||
"github.com/velour/catbase/plugins/gpt3"
|
|
||||||
"github.com/velour/catbase/plugins/last"
|
"github.com/velour/catbase/plugins/last"
|
||||||
"github.com/velour/catbase/plugins/mayi"
|
"github.com/velour/catbase/plugins/mayi"
|
||||||
"github.com/velour/catbase/plugins/quotegame"
|
"github.com/velour/catbase/plugins/quotegame"
|
||||||
|
@ -163,7 +162,6 @@ func main() {
|
||||||
b.AddPlugin(achievements.New(b))
|
b.AddPlugin(achievements.New(b))
|
||||||
b.AddPlugin(sms.New(b))
|
b.AddPlugin(sms.New(b))
|
||||||
b.AddPlugin(countdown.New(b))
|
b.AddPlugin(countdown.New(b))
|
||||||
b.AddPlugin(gpt3.New(b))
|
|
||||||
b.AddPlugin(rest.New(b))
|
b.AddPlugin(rest.New(b))
|
||||||
b.AddPlugin(quotegame.New(b))
|
b.AddPlugin(quotegame.New(b))
|
||||||
// catches anything left, will always return true
|
// catches anything left, will always return true
|
||||||
|
|
|
@ -215,6 +215,8 @@ func (p *FirstPlugin) register() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug().Interface("message", r.Msg).Msg("first check")
|
||||||
|
|
||||||
first, err := getLastFirst(p.db, r.Msg.Channel)
|
first, err := getLastFirst(p.db, r.Msg.Channel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
|
|
@ -1,120 +0,0 @@
|
||||||
package gpt3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"math/rand"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
"github.com/velour/catbase/bot"
|
|
||||||
"github.com/velour/catbase/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
const gpt3URL = "https://api.openai.com/v1/engines/%s/completions"
|
|
||||||
|
|
||||||
type GPT3Plugin struct {
|
|
||||||
b bot.Bot
|
|
||||||
c *config.Config
|
|
||||||
h bot.HandlerTable
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(b bot.Bot) *GPT3Plugin {
|
|
||||||
p := &GPT3Plugin{
|
|
||||||
b: b,
|
|
||||||
c: b.Config(),
|
|
||||||
}
|
|
||||||
p.register()
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *GPT3Plugin) register() {
|
|
||||||
p.h = bot.HandlerTable{
|
|
||||||
{
|
|
||||||
Kind: bot.Message, IsCmd: true,
|
|
||||||
Regex: regexp.MustCompile(`(?is)^gpt3 (?P<text>.*)`),
|
|
||||||
HelpText: "request text completion",
|
|
||||||
Handler: p.message,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
log.Debug().Msg("Registering GPT3 handlers")
|
|
||||||
p.b.RegisterTable(p, p.h)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *GPT3Plugin) message(r bot.Request) bool {
|
|
||||||
stem := r.Values["text"]
|
|
||||||
log.Debug().Msgf("Got GPT3 request: %s", stem)
|
|
||||||
postStruct := gpt3Request{
|
|
||||||
Prompt: stem,
|
|
||||||
MaxTokens: p.c.GetInt("gpt3.tokens", 16),
|
|
||||||
Temperature: p.c.GetFloat64("gpt3.temperature", 1),
|
|
||||||
TopP: p.c.GetFloat64("gpt3.top_p", 1),
|
|
||||||
N: p.c.GetInt("gpt3.n", 1),
|
|
||||||
Stop: p.c.GetArray("gpt3.stop", []string{"\n"}),
|
|
||||||
Echo: true,
|
|
||||||
}
|
|
||||||
postBody, _ := json.Marshal(postStruct)
|
|
||||||
client := &http.Client{}
|
|
||||||
u := fmt.Sprintf(gpt3URL, p.c.Get("gpt3.engine", "ada"))
|
|
||||||
req, err := http.NewRequest("POST", u, bytes.NewBuffer(postBody))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("could not make gpt3 request")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
gpt3Key := p.c.Get("gpt3.bearer", "")
|
|
||||||
if gpt3Key == "" {
|
|
||||||
log.Error().Msgf("no GPT3 key given")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-Type", "application/json")
|
|
||||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key))
|
|
||||||
res, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("could not make gpt3 request")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
resBody, _ := ioutil.ReadAll(res.Body)
|
|
||||||
gpt3Resp := gpt3Response{}
|
|
||||||
err = json.Unmarshal(resBody, &gpt3Resp)
|
|
||||||
|
|
||||||
log.Debug().
|
|
||||||
Str("body", string(resBody)).
|
|
||||||
Interface("resp", gpt3Resp).
|
|
||||||
Msg("OpenAI Response")
|
|
||||||
|
|
||||||
msg := "OpenAI is too shitty to respond to that."
|
|
||||||
if len(gpt3Resp.Choices) > 0 {
|
|
||||||
msg = gpt3Resp.Choices[rand.Intn(len(gpt3Resp.Choices))].Text
|
|
||||||
}
|
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
type gpt3Request struct {
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
MaxTokens int `json:"max_tokens"`
|
|
||||||
Temperature float64 `json:"temperature"`
|
|
||||||
TopP float64 `json:"top_p"`
|
|
||||||
N int `json:"n"`
|
|
||||||
Stream bool `json:"stream"`
|
|
||||||
Logprobs interface{} `json:"logprobs"`
|
|
||||||
Stop []string `json:"stop"`
|
|
||||||
Echo bool `json:"echo"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type gpt3Response struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int `json:"created"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Choices []struct {
|
|
||||||
Text string `json:"text"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
Logprobs interface{} `json:"logprobs"`
|
|
||||||
FinishReason string `json:"finish_reason"`
|
|
||||||
} `json:"choices"`
|
|
||||||
}
|
|
Loading…
Reference in New Issue