mirror of https://github.com/velour/catbase.git
rest: bypass jq when it's convenient
Adding a custom parser for GPT-2 queries because it's easier this way.
This commit is contained in:
parent
da69df36d4
commit
646282b233
|
@ -28,6 +28,24 @@ type RestPlugin struct {
|
||||||
handlers bot.HandlerTable
|
handlers bot.HandlerTable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type postProcessor func(interface{}) string
|
||||||
|
|
||||||
|
var postProcessors = map[string]postProcessor{
|
||||||
|
"gpt2": func(input interface{}) string {
|
||||||
|
values := input.(map[string]interface{})
|
||||||
|
text := values["text"].(string)
|
||||||
|
lastStop := strings.LastIndexAny(text, ".!?")
|
||||||
|
if lastStop > 0 {
|
||||||
|
text = text[:lastStop+1]
|
||||||
|
}
|
||||||
|
eot := strings.LastIndex(text, "<|endoftext|>")
|
||||||
|
if eot > 0 {
|
||||||
|
text = text[:eot]
|
||||||
|
}
|
||||||
|
return text
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func New(b bot.Bot) *RestPlugin {
|
func New(b bot.Bot) *RestPlugin {
|
||||||
p := &RestPlugin{
|
p := &RestPlugin{
|
||||||
b: b,
|
b: b,
|
||||||
|
@ -298,25 +316,29 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler {
|
||||||
var returnValues interface{}
|
var returnValues interface{}
|
||||||
json.Unmarshal(body, &returnValues)
|
json.Unmarshal(body, &returnValues)
|
||||||
|
|
||||||
query, err := gojq.Parse(w.ReturnField)
|
|
||||||
if err != nil {
|
|
||||||
msg := fmt.Sprintf("Wire handler did not find return value: %s => `%s`", w.URL, w.ReturnField)
|
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg string
|
var msg string
|
||||||
iter := query.Run(returnValues) // or query.RunWithContext
|
if pp, ok := postProcessors[w.ReturnField]; ok {
|
||||||
|
msg = pp(returnValues)
|
||||||
|
} else {
|
||||||
|
query, err := gojq.Parse(w.ReturnField)
|
||||||
|
if err != nil {
|
||||||
|
msg := fmt.Sprintf("Wire handler did not find return value: %s => `%s`", w.URL, w.ReturnField)
|
||||||
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
iter := query.Run(returnValues) // or query.RunWithContext
|
||||||
v, ok := iter.Next()
|
|
||||||
if !ok {
|
for {
|
||||||
break
|
v, ok := iter.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err, ok := v.(error); ok {
|
||||||
|
return p.handleErr(err, r)
|
||||||
|
}
|
||||||
|
msg += fmt.Sprintf("%s\n", v)
|
||||||
}
|
}
|
||||||
if err, ok := v.(error); ok {
|
|
||||||
return p.handleErr(err, r)
|
|
||||||
}
|
|
||||||
msg += fmt.Sprintf("%s\n", v)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msg = strings.TrimSpace(msg)
|
msg = strings.TrimSpace(msg)
|
||||||
|
@ -327,7 +349,7 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler {
|
||||||
|
|
||||||
func (p *RestPlugin) handleErr(err error, r bot.Request) bool {
|
func (p *RestPlugin) handleErr(err error, r bot.Request) bool {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Error: %s", err)
|
p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Error: %s", err))
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
|
Loading…
Reference in New Issue