From eb67d1a35e43dc65be91306fe9864466ad606dfb Mon Sep 17 00:00:00 2001 From: Chris Sexton <3216719+chrissexton@users.noreply.github.com> Date: Tue, 7 Feb 2023 10:17:07 -0500 Subject: [PATCH] gpt3: moderation to protect skiesel's innocence --- plugins/gpt3/gpt3.go | 98 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 90 insertions(+), 8 deletions(-) diff --git a/plugins/gpt3/gpt3.go b/plugins/gpt3/gpt3.go index bdcba83..7f350e0 100644 --- a/plugins/gpt3/gpt3.go +++ b/plugins/gpt3/gpt3.go @@ -4,10 +4,12 @@ import ( "bytes" "encoding/json" "fmt" - "io/ioutil" + "io" "math/rand" "net/http" + "reflect" "regexp" + "strings" "github.com/rs/zerolog/log" "github.com/velour/catbase/bot" @@ -15,6 +17,7 @@ import ( ) const gpt3URL = "https://api.openai.com/v1/engines/%s/completions" +const gpt3ModURL = "https://api.openai.com/v1/moderations" type GPT3Plugin struct { b bot.Bot @@ -46,7 +49,15 @@ func (p *GPT3Plugin) register() { func (p *GPT3Plugin) message(r bot.Request) bool { stem := r.Values["text"] + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, p.mkRequest(stem)) + return true +} + +func (p *GPT3Plugin) mkRequest(stem string) string { log.Debug().Msgf("Got GPT3 request: %s", stem) + if err := p.checkStem(stem); err != nil { + return "GPT3 Moderation " + err.Error() + } postStruct := gpt3Request{ Prompt: stem, MaxTokens: p.c.GetInt("gpt3.tokens", 16), @@ -62,22 +73,21 @@ func (p *GPT3Plugin) message(r bot.Request) bool { req, err := http.NewRequest("POST", u, bytes.NewBuffer(postBody)) if err != nil { log.Error().Err(err).Msg("could not make gpt3 request") - return false + return err.Error() } gpt3Key := p.c.Get("gpt3.bearer", "") if gpt3Key == "" { log.Error().Msgf("no GPT3 key given") - return false + return "No GPT3 API key" } req.Header.Add("Content-Type", "application/json") req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key)) res, err := client.Do(req) if err != nil { - log.Error().Err(err).Msg("could not make gpt3 request") - return false + return err.Error() } - resBody, _ := ioutil.ReadAll(res.Body) + resBody, _ := io.ReadAll(res.Body) gpt3Resp := gpt3Response{} err = json.Unmarshal(resBody, &gpt3Resp) @@ -90,8 +100,50 @@ func (p *GPT3Plugin) message(r bot.Request) bool { if len(gpt3Resp.Choices) > 0 { msg = gpt3Resp.Choices[rand.Intn(len(gpt3Resp.Choices))].Text } - p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg) - return true + return msg +} + +func (p *GPT3Plugin) checkStem(stem string) error { + postBody, _ := json.Marshal(gpt3ModRequest{Input: stem}) + client := &http.Client{} + req, err := http.NewRequest("POST", gpt3ModURL, bytes.NewBuffer(postBody)) + if err != nil { + return err + } + gpt3Key := p.c.Get("gpt3.bearer", "") + if gpt3Key == "" { + return fmt.Errorf("no GPT3 API key") + } + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key)) + res, err := client.Do(req) + if err != nil { + return err + } + + resBody, _ := io.ReadAll(res.Body) + log.Debug().Str("resBody", string(resBody)).Msg("res") + gpt3Resp := gpt3Moderation{} + err = json.Unmarshal(resBody, &gpt3Resp) + if err != nil { + return err + } + log.Debug().Interface("GPT3 Moderation", gpt3Resp).Msg("Moderation result") + for _, res := range gpt3Resp.Results { + if res.Flagged { + list := "" + categories := reflect.ValueOf(res.Categories) + fields := reflect.VisibleFields(reflect.TypeOf(res.Categories)) + for i := 0; i < categories.NumField(); i++ { + if categories.Field(i).Bool() { + list += fields[i].Name + ", " + } + } + list = strings.TrimSuffix(list, ", ") + return fmt.Errorf("flagged: %s", list) + } + } + return nil } type gpt3Request struct { @@ -106,6 +158,10 @@ type gpt3Request struct { Echo bool `json:"echo"` } +type gpt3ModRequest struct { + Input string `json:"input"` +} + type gpt3Response struct { ID string `json:"id"` Object string `json:"object"` @@ -118,3 +174,29 @@ type gpt3Response struct { FinishReason string `json:"finish_reason"` } `json:"choices"` } + +type gpt3Moderation struct { + ID string `json:"id"` + Model string `json:"model"` + Results []struct { + Categories struct { + Hate bool `json:"hate"` + HateThreatening bool `json:"hate/threatening"` + SelfHarm bool `json:"self-harm"` + Sexual bool `json:"sexual"` + SexualMinors bool `json:"sexual/minors"` + Violence bool `json:"violence"` + ViolenceGraphic bool `json:"violence/graphic"` + } `json:"categories"` + CategoryScores struct { + Hate float64 `json:"hate"` + HateThreatening float64 `json:"hate/threatening"` + SelfHarm float64 `json:"self-harm"` + Sexual float64 `json:"sexual"` + SexualMinors float64 `json:"sexual/minors"` + Violence float64 `json:"violence"` + ViolenceGraphic float64 `json:"violence/graphic"` + } `json:"category_scores"` + Flagged bool `json:"flagged"` + } `json:"results"` +}