diff --git a/plugins/rest/rest.go b/plugins/rest/rest.go index 3f03347..2dccf6c 100644 --- a/plugins/rest/rest.go +++ b/plugins/rest/rest.go @@ -2,6 +2,7 @@ package rest import ( "bytes" + "crypto/sha512" "database/sql" "encoding/json" "errors" @@ -64,7 +65,8 @@ func (p *RestPlugin) setupDB() { id integer primary key autoincrement, url text not null, parse_regex text not null, - return_field text not null + return_field text not null, + body text not null )`) if err := tx.Commit(); err != nil { panic(err) @@ -85,10 +87,18 @@ func (p *RestPlugin) register() { Regex: regexp.MustCompile("(?i)^testwire `(?P[^`]+)` to (?P\\S+) `(?P[^`]+)` => (?P.*)$"), HelpText: "Tests a new REST function", Handler: p.handleTestWire}, + bot.HandlerSpec{Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile("(?i)^testwirepost `(?P[^`]+)` to (?P\\S+) with body `(?P[^`]+)` `(?P[^`]+)` => (?P.*)$"), + HelpText: "Tests a new REST function", + Handler: p.handleTestWire}, bot.HandlerSpec{Kind: bot.Message, IsCmd: true, Regex: regexp.MustCompile("(?i)^wire `(?P[^`]+)` to (?P\\S+) `(?P[^`]+)`$"), HelpText: "Registers a new REST function", Handler: p.handleWire}, + bot.HandlerSpec{Kind: bot.Message, IsCmd: true, + Regex: regexp.MustCompile("(?i)^wirepost `(?P[^`]+)` to (?P\\S+) with body `(?P[^`]+)` `(?P[^`]+)`$"), + HelpText: "Registers a new REST function", + Handler: p.handleWire}, } p.b.RegisterTable(p, p.handlers) wires, err := p.getWires() @@ -98,6 +108,15 @@ func (p *RestPlugin) register() { for _, w := range wires { p.b.RegisterRegex(p, bot.Message, w.ParseRegex.Regexp, p.mkHandler(w)) } + p.b.RegisterRegex(p, bot.Help, regexp.MustCompile(`.*`), func(r bot.Request) bool { + out := "Rest commands:\n" + for _, h := range p.handlers { + out += fmt.Sprintf("```%s```\t%s", h.Regex.String(), h.HelpText) + } + out = strings.TrimSpace(out) + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, out) + return true + }) } type ScanableRegexp struct { @@ -149,6 +168,8 @@ type wire struct { ParseRegex ScanableRegexp `db:"parse_regex"` // The JSON field that will contain the REST return value ReturnField string `db:"return_field"` + // Body parse for POST + Body string `db:"body"` } func (w wire) String() string { @@ -228,15 +249,22 @@ func (p *RestPlugin) rmWire(r bot.Request) bool { func (p *RestPlugin) mkWire(r bot.Request) (wire, error) { var w wire var err error + + // if it's not a POST, this will be empty (and indicate we will use GET) + w.Body = r.Values["body"] + w.ParseRegex.Regexp, err = regexp.Compile(r.Values["parse"]) if err != nil { return w, err } + w.URL.URL, err = url.Parse(r.Values["url"]) if err != nil { return w, err } + w.ReturnField = r.Values["returnField"] + return w, nil } @@ -245,6 +273,10 @@ func (p *RestPlugin) handleWire(r bot.Request) bool { var msg string var err error w, err = p.mkWire(r) + if err != nil { + msg = err.Error() + goto SEND + } err = w.Save(p.db) if err != nil { msg = err.Error() @@ -294,14 +326,42 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler { return true } newURL, err := url.Parse(buf.String()) + if p.handleErr(err, r) { + return true + } + bodyTpl := w.Body + parse, err = template.New(bodyTpl).Parse(bodyTpl) + if p.handleErr(err, r) { + return true + } + bodyBuf := bytes.Buffer{} + err = parse.Execute(&bodyBuf, values) + if p.handleErr(err, r) { + return true + } + // Note, we will expect this is formatted good enough to send over the wire (?) log.Debug(). Interface("values", values). - Str("URL", buf.String()). + Str("URL", newURL.String()). + Str("body", bodyBuf.String()). Msg("Querying URL with values") if p.handleErr(err, r) { return true } - resp, err := http.Get(newURL.String()) + client := &http.Client{} + method := "GET" + if bodyBuf.Len() > 0 { + method = "POST" + } + req, err := http.NewRequest(method, newURL.String(), bytes.NewReader(bodyBuf.Bytes())) + if p.handleErr(err, r) { + return true + } + key := fmt.Sprintf("rest.%x.origin", sha512.Sum512([]byte(urlStr))) + if origin := p.b.Config().Get(key, ""); origin != "" { + req.Header.Add("Origin", origin) + } + resp, err := client.Do(req) if p.handleErr(err, r) { return true } @@ -313,6 +373,7 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler { if p.handleErr(err, r) { return true } + log.Debug().Str("body", string(body)).Msg("got a body back") var returnValues interface{} json.Unmarshal(body, &returnValues) @@ -342,6 +403,10 @@ func (p *RestPlugin) mkHandler(w wire) bot.ResponseHandler { } msg = strings.TrimSpace(msg) + if msg == "" { + p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "we got an empty body after all we've done") + return true + } p.b.Send(r.Conn, bot.Message, r.Msg.Channel, msg) return true }