Compare commits

..

No commits in common. "68738f847b3dff439ee442d3a9e8cea057762c84" and "91d21c10766d8578705ee8152ba2a85446320ed6" have entirely different histories.

1 changed files with 8 additions and 93 deletions

View File

@ -4,12 +4,10 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io/ioutil"
"math/rand" "math/rand"
"net/http" "net/http"
"reflect"
"regexp" "regexp"
"strings"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/velour/catbase/bot" "github.com/velour/catbase/bot"
@ -17,7 +15,6 @@ import (
) )
const gpt3URL = "https://api.openai.com/v1/engines/%s/completions" const gpt3URL = "https://api.openai.com/v1/engines/%s/completions"
const gpt3ModURL = "https://api.openai.com/v1/moderations"
type GPT3Plugin struct { type GPT3Plugin struct {
b bot.Bot b bot.Bot
@ -49,15 +46,7 @@ func (p *GPT3Plugin) register() {
func (p *GPT3Plugin) message(r bot.Request) bool { func (p *GPT3Plugin) message(r bot.Request) bool {
stem := r.Values["text"] stem := r.Values["text"]
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, p.mkRequest(stem))
return true
}
func (p *GPT3Plugin) mkRequest(stem string) string {
log.Debug().Msgf("Got GPT3 request: %s", stem) log.Debug().Msgf("Got GPT3 request: %s", stem)
if err := p.checkStem(stem); err != nil {
return "GPT3 moderation " + err.Error()
}
postStruct := gpt3Request{ postStruct := gpt3Request{
Prompt: stem, Prompt: stem,
MaxTokens: p.c.GetInt("gpt3.tokens", 16), MaxTokens: p.c.GetInt("gpt3.tokens", 16),
@ -73,21 +62,22 @@ func (p *GPT3Plugin) mkRequest(stem string) string {
req, err := http.NewRequest("POST", u, bytes.NewBuffer(postBody)) req, err := http.NewRequest("POST", u, bytes.NewBuffer(postBody))
if err != nil { if err != nil {
log.Error().Err(err).Msg("could not make gpt3 request") log.Error().Err(err).Msg("could not make gpt3 request")
return err.Error() return false
} }
gpt3Key := p.c.Get("gpt3.bearer", "") gpt3Key := p.c.Get("gpt3.bearer", "")
if gpt3Key == "" { if gpt3Key == "" {
log.Error().Msgf("no GPT3 key given") log.Error().Msgf("no GPT3 key given")
return "No GPT3 API key" return false
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key)) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key))
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return err.Error() log.Error().Err(err).Msg("could not make gpt3 request")
return false
} }
resBody, _ := io.ReadAll(res.Body) resBody, _ := ioutil.ReadAll(res.Body)
gpt3Resp := gpt3Response{} gpt3Resp := gpt3Response{}
err = json.Unmarshal(resBody, &gpt3Resp) err = json.Unmarshal(resBody, &gpt3Resp)
@ -100,53 +90,8 @@ func (p *GPT3Plugin) mkRequest(stem string) string {
if len(gpt3Resp.Choices) > 0 { if len(gpt3Resp.Choices) > 0 {
msg = gpt3Resp.Choices[rand.Intn(len(gpt3Resp.Choices))].Text msg = gpt3Resp.Choices[rand.Intn(len(gpt3Resp.Choices))].Text
} }
return msg p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
} return true
func (p *GPT3Plugin) checkStem(stem string) error {
if !p.c.GetBool("gpt3.moderation", true) {
return nil
}
postBody, _ := json.Marshal(gpt3ModRequest{Input: stem})
client := &http.Client{}
req, err := http.NewRequest("POST", gpt3ModURL, bytes.NewBuffer(postBody))
if err != nil {
return err
}
gpt3Key := p.c.Get("gpt3.bearer", "")
if gpt3Key == "" {
return fmt.Errorf("no GPT3 API key")
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", gpt3Key))
res, err := client.Do(req)
if err != nil {
return err
}
resBody, _ := io.ReadAll(res.Body)
log.Debug().Str("resBody", string(resBody)).Msg("res")
gpt3Resp := gpt3Moderation{}
err = json.Unmarshal(resBody, &gpt3Resp)
if err != nil {
return err
}
log.Debug().Interface("GPT3 Moderation", gpt3Resp).Msg("Moderation result")
for _, res := range gpt3Resp.Results {
if res.Flagged {
list := ""
categories := reflect.ValueOf(res.Categories)
fields := reflect.VisibleFields(reflect.TypeOf(res.Categories))
for i := 0; i < categories.NumField(); i++ {
if categories.Field(i).Bool() {
list += fields[i].Name + ", "
}
}
list = strings.TrimSuffix(list, ", ")
return fmt.Errorf("flagged: %s", list)
}
}
return nil
} }
type gpt3Request struct { type gpt3Request struct {
@ -161,10 +106,6 @@ type gpt3Request struct {
Echo bool `json:"echo"` Echo bool `json:"echo"`
} }
type gpt3ModRequest struct {
Input string `json:"input"`
}
type gpt3Response struct { type gpt3Response struct {
ID string `json:"id"` ID string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
@ -177,29 +118,3 @@ type gpt3Response struct {
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} `json:"choices"` } `json:"choices"`
} }
type gpt3Moderation struct {
ID string `json:"id"`
Model string `json:"model"`
Results []struct {
Categories struct {
Hate bool `json:"hate"`
HateThreatening bool `json:"hate/threatening"`
SelfHarm bool `json:"self-harm"`
Sexual bool `json:"sexual"`
SexualMinors bool `json:"sexual/minors"`
Violence bool `json:"violence"`
ViolenceGraphic bool `json:"violence/graphic"`
} `json:"categories"`
CategoryScores struct {
Hate float64 `json:"hate"`
HateThreatening float64 `json:"hate/threatening"`
SelfHarm float64 `json:"self-harm"`
Sexual float64 `json:"sexual"`
SexualMinors float64 `json:"sexual/minors"`
Violence float64 `json:"violence"`
ViolenceGraphic float64 `json:"violence/graphic"`
} `json:"category_scores"`
Flagged bool `json:"flagged"`
} `json:"results"`
}