rest: bypass jq when it's convenient

Adding a custom parser for GPT-2 queries because it's easier this way.
This commit is contained in:
Chris Sexton 2021-09-03 09:43:37 -04:00 committed by Chris Sexton
parent da69df36d4
commit 646282b233
1 changed files with 39 additions and 17 deletions

View File

@ -28,6 +28,24 @@ type RestPlugin struct {
handlers bot.HandlerTable handlers bot.HandlerTable
} }
type postProcessor func(interface{}) string
var postProcessors = map[string]postProcessor{
"gpt2": func(input interface{}) string {
values := input.(map[string]interface{})
text := values["text"].(string)
lastStop := strings.LastIndexAny(text, ".!?")
if lastStop > 0 {
text = text[:lastStop+1]
}
eot := strings.LastIndex(text, "<|endoftext|>")
if eot > 0 {
text = text[:eot]
}
return text
},
}
func New(b bot.Bot) *RestPlugin { func New(b bot.Bot) *RestPlugin {
p := &RestPlugin{ p := &RestPlugin{
b: b, b: b,
@ -298,25 +316,29 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler {
var returnValues interface{} var returnValues interface{}
json.Unmarshal(body, &returnValues) json.Unmarshal(body, &returnValues)
query, err := gojq.Parse(w.ReturnField)
if err != nil {
msg := fmt.Sprintf("Wire handler did not find return value: %s => `%s`", w.URL, w.ReturnField)
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
return true
}
var msg string var msg string
iter := query.Run(returnValues) // or query.RunWithContext if pp, ok := postProcessors[w.ReturnField]; ok {
msg = pp(returnValues)
} else {
query, err := gojq.Parse(w.ReturnField)
if err != nil {
msg := fmt.Sprintf("Wire handler did not find return value: %s => `%s`", w.URL, w.ReturnField)
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
return true
}
for { iter := query.Run(returnValues) // or query.RunWithContext
v, ok := iter.Next()
if !ok { for {
break v, ok := iter.Next()
if !ok {
break
}
if err, ok := v.(error); ok {
return p.handleErr(err, r)
}
msg += fmt.Sprintf("%s\n", v)
} }
if err, ok := v.(error); ok {
return p.handleErr(err, r)
}
msg += fmt.Sprintf("%s\n", v)
} }
msg = strings.TrimSpace(msg) msg = strings.TrimSpace(msg)
@ -327,7 +349,7 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler {
func (p *RestPlugin) handleErr(err error, r bot.Request) bool { func (p *RestPlugin) handleErr(err error, r bot.Request) bool {
if err != nil { if err != nil {
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: %s", err) p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Error: %s", err))
return true return true
} }
return false return false