rest: bypass jq when it's convenient

Adding a custom parser for GPT-2 queries because it's easier this way.
This commit is contained in:
Chris Sexton 2021-09-03 09:43:37 -04:00 committed by Chris Sexton
parent da69df36d4
commit 646282b233
1 changed files with 39 additions and 17 deletions

View File

@ -28,6 +28,24 @@ type RestPlugin struct {
handlers bot.HandlerTable handlers bot.HandlerTable
} }
type postProcessor func(interface{}) string
var postProcessors = map[string]postProcessor{
"gpt2": func(input interface{}) string {
values := input.(map[string]interface{})
text := values["text"].(string)
lastStop := strings.LastIndexAny(text, ".!?")
if lastStop > 0 {
text = text[:lastStop+1]
}
eot := strings.LastIndex(text, "<|endoftext|>")
if eot > 0 {
text = text[:eot]
}
return text
},
}
func New(b bot.Bot) *RestPlugin { func New(b bot.Bot) *RestPlugin {
p := &RestPlugin{ p := &RestPlugin{
b: b, b: b,
@ -298,6 +316,10 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler {
var returnValues interface{} var returnValues interface{}
json.Unmarshal(body, &returnValues) json.Unmarshal(body, &returnValues)
var msg string
if pp, ok := postProcessors[w.ReturnField]; ok {
msg = pp(returnValues)
} else {
query, err := gojq.Parse(w.ReturnField) query, err := gojq.Parse(w.ReturnField)
if err != nil { if err != nil {
msg := fmt.Sprintf("Wire handler did not find return value: %s => `%s`", w.URL, w.ReturnField) msg := fmt.Sprintf("Wire handler did not find return value: %s => `%s`", w.URL, w.ReturnField)
@ -305,7 +327,6 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler {
return true return true
} }
var msg string
iter := query.Run(returnValues) // or query.RunWithContext iter := query.Run(returnValues) // or query.RunWithContext
for { for {
@ -318,6 +339,7 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler {
} }
msg += fmt.Sprintf("%s\n", v) msg += fmt.Sprintf("%s\n", v)
} }
}
msg = strings.TrimSpace(msg) msg = strings.TrimSpace(msg)
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg) p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
@ -327,7 +349,7 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler {
func (p *RestPlugin) handleErr(err error, r bot.Request) bool { func (p *RestPlugin) handleErr(err error, r bot.Request) bool {
if err != nil { if err != nil {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: %s", err) p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Error: %s", err))
return true return true
} }
return false return false