catbase/plugins/gpt/gpt3.go

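// Package gpt implements a catbase plugin that sends chat messages to the
// OpenAI completion and moderation APIs and replies with the results.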
package gpt

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"net/http"
	"reflect"
	"regexp"
	"strings"

	"github.com/rs/zerolog/log"
	"github.com/velour/catbase/bot"
	"github.com/velour/catbase/config"
)
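
// gpt3URL is the legacy engine-scoped completions endpoint; the %s is filled
// with the configured "gpt3.engine". gpt3ModURL is the moderation endpoint.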
const gpt3URL = "https://api.openai.com/v1/engines/%s/completions"
const gpt3ModURL = "https://api.openai.com/v1/moderations"
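
// GPTPlugin holds the bot handle, config, and handler table for the GPT
// commands.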
type GPTPlugin struct {
	b         bot.Bot
	c         *config.Config
	h         bot.HandlerTable
	chatCount int
}
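
// New creates the plugin and registers its handlers with the bot.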
func New(b bot.Bot) *GPTPlugin {
	p := &GPTPlugin{
		b: b,
		c: b.Config(),
	}
	p.register()
	return p
}
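
// register installs the handlers: "gpt3" for raw completions, "gpt" for chat
// completions, "gpt-prompt:" to set the ChatGPT prompt, and a catch-all
// command pattern that also goes to chat completion.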
func (p *GPTPlugin) register() {
	p.h = bot.HandlerTable{
		{
			Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile(`(?is)^gpt3 (?P<text>.*)`),
			HelpText: "request text completion",
			Handler:  p.message,
		},
		{
			Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile(`(?is)^gpt (?P<text>.*)`),
			HelpText: "chat completion",
			Handler:  p.chatMessage,
		},
		{
			Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile(`(?is)^gpt-prompt: (?P<text>.*)`),
			HelpText: "set the ChatGPT prompt",
			Handler:  p.setPromptMessage,
		},
		{
			Kind: bot.Message, IsCmd: true,
			Regex:   regexp.MustCompile(`(?P<text>.*)`),
			Handler: p.chatMessage,
		},
	}
	log.Debug().Msg("Registering GPT3 handlers")
	p.b.RegisterTable(p, p.h)
}
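
// setPromptMessage handles "gpt-prompt:" by storing the new chat prompt via
// setPrompt (not shown in this excerpt) and confirming it to the channel.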
func (p *GPTPlugin) setPromptMessage(r bot.Request) bool {
	prompt := r.Values["text"]
	if err := p.setPrompt(prompt); err != nil {
		resp := fmt.Sprintf("Error: %s", err)
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, resp)
		// report the failure and stop rather than also claiming success
		return true
	}
	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf(`Okay. I set the prompt to: "%s"`, prompt))
	return true
}
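
// chatMessage forwards the text to chatGPT (not shown in this excerpt) and
// replies with the completion, or with the error if the request failed.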
func (p *GPTPlugin) chatMessage(r bot.Request) bool {
	resp, err := p.chatGPT(r.Values["text"])
	if err != nil {
		resp = fmt.Sprintf("Error: %s", err)
	}
	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, resp)
	return true
}
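
// message handles the "gpt3" command by requesting a plain text completion.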
func (p *GPTPlugin) message(r bot.Request) bool {
	stem := r.Values["text"]
	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, p.gpt3(stem))
	return true
}
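
// gpt3 runs the stem through moderation and then requests a completion.
// Tunables come from config with these defaults: gpt3.tokens (16),
// gpt3.temperature (1), gpt3.top_p (1), gpt3.n (1), gpt3.stop ("\n"),
// gpt3.echo (false); gpt3.engine and gpt3.bearer are read in mkRequest.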
func (p *GPTPlugin) gpt3(stem string) string {
	log.Debug().Msgf("Got GPT3 request: %s", stem)
	if err := p.checkStem(stem); err != nil {
		return "GPT3 moderation " + err.Error()
	}
	postStruct := gpt3Request{
		Prompt:      stem,
		MaxTokens:   p.c.GetInt("gpt3.tokens", 16),
		Temperature: p.c.GetFloat64("gpt3.temperature", 1),
		TopP:        p.c.GetFloat64("gpt3.top_p", 1),
		N:           p.c.GetInt("gpt3.n", 1),
		Stop:        p.c.GetArray("gpt3.stop", []string{"\n"}),
		Echo:        p.c.GetBool("gpt3.echo", false),
	}
	val, err := p.mkRequest(gpt3URL, postStruct)
	if err != nil {
		return err.Error()
	}
	choices := val.(gpt3Response).Choices
	if len(choices) > 0 {
		return choices[rand.Intn(len(choices))].Text
	}
	return "OpenAI is too shitty to respond to that."
}
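
// mkRequest POSTs postStruct to endPoint (with the configured engine
// substituted into the URL) and decodes the body as a gpt3Response.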
func (p *GPTPlugin) mkRequest(endPoint string, postStruct interface{}) (interface{}, error) {
	postBody, _ := json.Marshal(postStruct)
	client := &http.Client{}
	u := fmt.Sprintf(endPoint, p.c.Get("gpt3.engine", "ada"))
	req, err := http.NewRequest("POST", u, bytes.NewBuffer(postBody))
	if err != nil {
		log.Error().Err(err).Msg("could not make gpt3 request")
		return nil, err
	}
	gpt3Key := p.c.Get("gpt3.bearer", "")
	if gpt3Key == "" {
		log.Error().Msgf("no GPT3 key given")
		return nil, fmt.Errorf("no GPT3 API key")
	}
	req.Header.Add("Content-Type", "application/json")
	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key))
	res, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	resBody, _ := io.ReadAll(res.Body)
	gpt3Resp := gpt3Response{}
	if err := json.Unmarshal(resBody, &gpt3Resp); err != nil {
		return nil, err
	}
	log.Debug().
		Str("body", string(resBody)).
		Interface("resp", gpt3Resp).
		Msg("OpenAI Response")
	return gpt3Resp, nil
}
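
// checkStem sends the stem to the moderation endpoint (unless
// gpt3.moderation is disabled) and returns an error naming any flagged
// categories.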
func (p *GPTPlugin) checkStem(stem string) error {
	if !p.c.GetBool("gpt3.moderation", true) {
		return nil
	}
	postBody, _ := json.Marshal(gpt3ModRequest{Input: stem})
	client := &http.Client{}
	req, err := http.NewRequest("POST", gpt3ModURL, bytes.NewBuffer(postBody))
	if err != nil {
		return err
	}
	gpt3Key := p.c.Get("gpt3.bearer", "")
	if gpt3Key == "" {
		return fmt.Errorf("no GPT3 API key")
	}
	req.Header.Add("Content-Type", "application/json")
	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key))
	res, err := client.Do(req)
	if err != nil {
		return err
	}
	defer res.Body.Close()
	resBody, _ := io.ReadAll(res.Body)
	log.Debug().Str("resBody", string(resBody)).Msg("res")
	gpt3Resp := gpt3Moderation{}
	err = json.Unmarshal(resBody, &gpt3Resp)
	if err != nil {
		return err
	}
	log.Debug().Interface("GPT3 Moderation", gpt3Resp).Msg("Moderation result")
	for _, res := range gpt3Resp.Results {
		if res.Flagged {
			// walk the Categories struct by reflection to list the flagged ones
			list := ""
			categories := reflect.ValueOf(res.Categories)
			fields := reflect.VisibleFields(reflect.TypeOf(res.Categories))
			for i := 0; i < categories.NumField(); i++ {
				if categories.Field(i).Bool() {
					list += fields[i].Name + ", "
				}
			}
			list = strings.TrimSuffix(list, ", ")
			return fmt.Errorf("flagged: %s", list)
		}
	}
	return nil
}
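
// gpt3Request and gpt3ModRequest are the JSON bodies sent to the completion
// and moderation endpoints, respectively.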
type gpt3Request struct {
	Prompt      string   `json:"prompt"`
	MaxTokens   int      `json:"max_tokens"`
	Temperature float64  `json:"temperature"`
	TopP        float64  `json:"top_p"`
	N           int      `json:"n"`
	Stream      bool     `json:"stream"`
	Logprobs    any      `json:"logprobs"`
	Stop        []string `json:"stop"`
	Echo        bool     `json:"echo"`
}

type gpt3ModRequest struct {
	Input string `json:"input"`
}
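
// gpt3Response models the completions API response; gpt3 picks a random
// entry from Choices.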
type gpt3Response struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int    `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Text         string `json:"text"`
		Index        int    `json:"index"`
		Logprobs     any    `json:"logprobs"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}
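
// gpt3Moderation models the moderation API response; checkStem reads
// Results[].Flagged and the per-category booleans.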
type gpt3Moderation struct {
	ID      string `json:"id"`
	Model   string `json:"model"`
	Results []struct {
		Categories struct {
			Hate            bool `json:"hate"`
			HateThreatening bool `json:"hate/threatening"`
			SelfHarm        bool `json:"self-harm"`
			Sexual          bool `json:"sexual"`
			SexualMinors    bool `json:"sexual/minors"`
			Violence        bool `json:"violence"`
			ViolenceGraphic bool `json:"violence/graphic"`
		} `json:"categories"`
		CategoryScores struct {
			Hate            float64 `json:"hate"`
			HateThreatening float64 `json:"hate/threatening"`
			SelfHarm        float64 `json:"self-harm"`
			Sexual          float64 `json:"sexual"`
			SexualMinors    float64 `json:"sexual/minors"`
			Violence        float64 `json:"violence"`
			ViolenceGraphic float64 `json:"violence/graphic"`
		} `json:"category_scores"`
		Flagged bool `json:"flagged"`
	} `json:"results"`
}