From 646282b233f101b9761f927a0f7efaa6a6706b38 Mon Sep 17 00:00:00 2001 From: Chris Sexton Date: Fri, 3 Sep 2021 09:43:37 -0400 Subject: [PATCH] rest: bypass jq when it's convenient Adding a custom parser for GPT-2 queries because it's easier this way. --- plugins/rest/rest.go | 56 ++++++++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/plugins/rest/rest.go b/plugins/rest/rest.go index 051e11c..3f03347 100644 --- a/plugins/rest/rest.go +++ b/plugins/rest/rest.go @@ -28,6 +28,24 @@ type RestPlugin struct { handlers bot.HandlerTable } +type postProcessor func(interface{}) string + +var postProcessors = map[string]postProcessor{ + "gpt2": func(input interface{}) string { + values := input.(map[string]interface{}) + text := values["text"].(string) + lastStop := strings.LastIndexAny(text, ".!?") + if lastStop > 0 { + text = text[:lastStop+1] + } + eot := strings.LastIndex(text, "<|endoftext|>") + if eot > 0 { + text = text[:eot] + } + return text + }, +} + func New(b bot.Bot) *RestPlugin { p := &RestPlugin{ b: b, @@ -298,25 +316,29 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler { var returnValues interface{} json.Unmarshal(body, &returnValues) - query, err := gojq.Parse(w.ReturnField) - if err != nil { - msg := fmt.Sprintf("Wire handler did not find return value: %s => `%s`", w.URL, w.ReturnField) - p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg) - return true - } - var msg string - iter := query.Run(returnValues) // or query.RunWithContext + if pp, ok := postProcessors[w.ReturnField]; ok { + msg = pp(returnValues) + } else { + query, err := gojq.Parse(w.ReturnField) + if err != nil { + msg := fmt.Sprintf("Wire handler did not find return value: %s => `%s`", w.URL, w.ReturnField) + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg) + return true + } - for { - v, ok := iter.Next() - if !ok { - break + iter := query.Run(returnValues) // or query.RunWithContext + + for { + v, ok := iter.Next() + if !ok { + break + } + if err, ok := v.(error); ok { + return p.handleErr(err, r) + } + msg += fmt.Sprintf("%s\n", v) } - if err, ok := v.(error); ok { - return p.handleErr(err, r) - } - msg += fmt.Sprintf("%s\n", v) } msg = strings.TrimSpace(msg) @@ -327,7 +349,7 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler { func (p *RestPlugin) handleErr(err error, r bot.Request) bool { if err != nil { - p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: %s", err) + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Error: %s", err)) return true } return false