2021-05-19 15:46:23 +00:00
package rest
import (
	"bytes"
	"crypto/sha512"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"text/template"
	"time"

	"github.com/itchyny/gojq"
	"github.com/jmoiron/sqlx"
	"github.com/rs/zerolog/log"
	"github.com/velour/catbase/bot"
)
// RestPlugin lets chat users bind message patterns ("wires") to REST
// endpoints and relays values selected from the JSON response back to chat.
type RestPlugin struct {
	b        bot.Bot          // bot this plugin is registered with
	db       *sqlx.DB         // database holding the wires table
	handlers bot.HandlerTable // static commands; also the source of help output
}
2022-03-22 01:32:44 +00:00
type postProcessor func ( any ) string
2021-09-03 13:43:37 +00:00
var postProcessors = map [ string ] postProcessor {
2022-03-22 01:32:44 +00:00
"gpt2" : func ( input any ) string {
values := input . ( map [ string ] any )
2021-09-03 13:43:37 +00:00
text := values [ "text" ] . ( string )
lastStop := strings . LastIndexAny ( text , ".!?" )
if lastStop > 0 {
text = text [ : lastStop + 1 ]
}
eot := strings . LastIndex ( text , "<|endoftext|>" )
if eot > 0 {
text = text [ : eot ]
}
return text
} ,
}
2021-05-19 15:46:23 +00:00
// New builds the REST plugin for b. It makes sure the wires table exists, then
// registers the static commands together with a handler for every stored wire.
func New(b bot.Bot) *RestPlugin {
	p := new(RestPlugin)
	p.b = b
	p.db = b.DB()
	p.handlers = bot.HandlerTable{}

	p.setupDB()
	p.register()
	return p
}
// setupDB ensures the wires table exists and has a body column. It panics on
// any database error, because the plugin cannot work without its table.
func (p *RestPlugin) setupDB() {
	tx := p.db.MustBegin()
	tx.MustExec(`create table if not exists wires (
		id integer primary key autoincrement,
		url text not null,
		parse_regex text not null,
		return_field text not null,
		body text not null default ''
	)`)
	if err := tx.Commit(); err != nil {
		panic(err)
	}

	// `create table if not exists` leaves an older table untouched. A table
	// created before POST support therefore lacks the body column that
	// wire.Body (db:"body") and the insert/update statements expect, so add it.
	// pragma_table_info is SQLite-specific; the schema above
	// (autoincrement) is already SQLite-only.
	var hasBody int
	if err := p.db.Get(&hasBody,
		`select count(*) from pragma_table_info('wires') where name = 'body'`); err != nil {
		panic(err)
	}
	if hasBody == 0 {
		p.db.MustExec(`alter table wires add column body text not null default ''`)
	}
}
// register installs the plugin's static commands, a message handler for every
// wire already stored in the database, and a help responder that lists the
// static commands. It panics if the stored wires cannot be loaded.
func (p *RestPlugin) register() {
	p.handlers = bot.HandlerTable{
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^list wires$"),
			HelpText: "Lists all REST functions",
			Handler:  p.listWires},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile(`(?i)^rm wire (?P<id>\d+)$`),
			HelpText: "Removes a wire by ID (use list to view)",
			Handler:  p.rmWire},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^testwire `(?P<parse>[^`]+)` to (?P<url>\\S+) `(?P<returnField>[^`]+)` => (?P<text>.*)$"),
			HelpText: "Tests a new REST function",
			Handler:  p.handleTestWire},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^testwirepost `(?P<parse>[^`]+)` to (?P<url>\\S+) with body `(?P<body>[^`]+)` `(?P<returnField>[^`]+)` => (?P<text>.*)$"),
			HelpText: "Tests a new REST function that POSTs a body",
			Handler:  p.handleTestWire},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^wire `(?P<parse>[^`]+)` to (?P<url>\\S+) `(?P<returnField>[^`]+)`$"),
			HelpText: "Registers a new REST function",
			Handler:  p.handleWire},
		bot.HandlerSpec{Kind: bot.Message, IsCmd: true,
			Regex:    regexp.MustCompile("(?i)^wirepost `(?P<parse>[^`]+)` to (?P<url>\\S+) with body `(?P<body>[^`]+)` `(?P<returnField>[^`]+)`$"),
			HelpText: "Registers a new REST function that POSTs a body",
			Handler:  p.handleWire},
	}
	p.b.RegisterTable(p, p.handlers)

	wires, err := p.getWires()
	if err != nil {
		panic(err)
	}
	for _, w := range wires {
		p.b.RegisterRegex(p, bot.Message, w.ParseRegex.Regexp, p.mkHandler(w))
	}

	p.b.RegisterRegex(p, bot.Help, regexp.MustCompile(`.*`), func(r bot.Request) bool {
		var out strings.Builder
		out.WriteString("Rest commands:\n")
		for _, h := range p.handlers {
			// One command per line; previously every entry ran together.
			fmt.Fprintf(&out, "```%s```\t%s\n", h.Regex.String(), h.HelpText)
		}
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, strings.TrimSpace(out.String()))
		return true
	})
}
type ScanableRegexp struct {
* regexp . Regexp
}
2022-03-22 01:32:44 +00:00
func ( s * ScanableRegexp ) Scan ( src any ) error {
2021-05-19 15:46:23 +00:00
var source string
switch src . ( type ) {
case string :
source = src . ( string )
default :
return errors . New ( "incompatible type for ScanableRegexp" )
}
r , err := regexp . Compile ( source )
if err != nil {
return err
}
s . Regexp = r
return nil
}
type ScanableURL struct {
* url . URL
}
2022-03-22 01:32:44 +00:00
func ( s * ScanableURL ) Scan ( src any ) error {
2021-05-19 15:46:23 +00:00
var source string
switch src . ( type ) {
case string :
source = src . ( string )
default :
return errors . New ( "incompatible type for ScanableURL" )
}
u , err := url . Parse ( source )
if err != nil {
return err
}
s . URL = u
return nil
}
type wire struct {
// ID
ID sql . NullInt64
// The URL to make a request to
URL ScanableURL
// The regex which will trigger this REST action
ParseRegex ScanableRegexp ` db:"parse_regex" `
// The JSON field that will contain the REST return value
ReturnField string ` db:"return_field" `
2021-09-16 22:20:25 +00:00
// Body parse for POST
Body string ` db:"body" `
2021-05-19 15:46:23 +00:00
}
func ( w wire ) String ( ) string {
msg := "Wire:"
msg += fmt . Sprintf ( "\nURL: %s" , w . URL )
msg += fmt . Sprintf ( "\nParsing to trigger: `%s`" , w . ParseRegex )
msg += fmt . Sprintf ( "\nReturn field: `%s`" , w . ReturnField )
return msg
}
2021-05-20 22:25:58 +00:00
func ( p * RestPlugin ) getWires ( ) ( [ ] wire , error ) {
wires := [ ] wire { }
2021-05-19 15:46:23 +00:00
err := p . db . Select ( & wires , ` select * from wires ` )
return wires , err
}
// deleteWire removes the wire with the given id. It returns an error when no
// such wire exists, so callers do not report a deletion that never happened.
func (p *RestPlugin) deleteWire(id int64) error {
	res, err := p.db.Exec(`delete from wires where id=?`, id)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return fmt.Errorf("no wire with id %d", id)
	}
	return nil
}
// Update writes every field of an already-saved wire back to its row. A wire
// without an ID is inserted via Save instead.
func (w *wire) Update(db *sqlx.DB) error {
	if !w.ID.Valid {
		return w.Save(db)
	}
	// body is included; previously it was never updated.
	_, err := db.Exec(`update wires set url=?, parse_regex=?, return_field=?, body=? where id=?`,
		w.URL.String(), w.ParseRegex.String(), w.ReturnField, w.Body, w.ID.Int64)
	return err
}
// Save persists w. A wire that already has an ID is updated in place.
// Otherwise a new row is inserted and w.ID is set to the new row id.
func (w *wire) Save(db *sqlx.DB) error {
	if w.ID.Valid {
		return w.Update(db)
	}
	// body must be written explicitly. The column is NOT NULL (without a
	// default in tables created by older code), and omitting it also dropped
	// POST bodies on the floor.
	res, err := db.Exec(`insert into wires (url, parse_regex, return_field, body) values (?, ?, ?, ?)`,
		w.URL.String(), w.ParseRegex.String(), w.ReturnField, w.Body)
	if err != nil {
		return err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return err
	}
	w.ID = sql.NullInt64{Int64: id, Valid: true}
	return nil
}
// listWires replies with one line per stored wire (id, trigger pattern, URL),
// or with the error text if the wires cannot be loaded. It always reports the
// message as handled.
func (p *RestPlugin) listWires(r bot.Request) bool {
	wires, err := p.getWires()

	var reply string
	if err != nil {
		reply = err.Error()
	} else {
		var sb strings.Builder
		sb.WriteString("Current wires:")
		for _, w := range wires {
			id, _ := w.ID.Value()
			fmt.Fprintf(&sb, "\n\t%d: `%s` => %s", id, w.ParseRegex, w.URL)
		}
		reply = sb.String()
	}

	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, reply)
	return true
}
// rmWire deletes the wire whose id was captured by the "rm wire <id>" command
// and reports the outcome to the channel.
//
// NOTE(review): only the database row is removed; the handler registered for
// the wire via RegisterRegex is not unregistered here, so it appears to stay
// active until restart — confirm.
func (p *RestPlugin) rmWire(r bot.Request) bool {
	// The command regex guarantees digits, but a value too large for int64
	// still fails to parse. Report that instead of deleting a clamped id.
	id, err := strconv.ParseInt(r.Values["id"], 10, 64)
	if err != nil {
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Could not delete wire: "+err.Error())
		return true
	}
	if err := p.deleteWire(id); err != nil {
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel, "Could not delete wire: "+err.Error())
		return true
	}
	p.b.Send(r.Conn, bot.Message, r.Msg.Channel, fmt.Sprintf("Deleted wire: %d", id))
	return true
}
2021-05-20 22:25:58 +00:00
func ( p * RestPlugin ) mkWire ( r bot . Request ) ( wire , error ) {
2021-05-19 15:46:23 +00:00
var w wire
var err error
2021-09-16 22:20:25 +00:00
// if it's not a POST, this will be empty (and indicate we will use GET)
w . Body = r . Values [ "body" ]
2021-05-19 15:46:23 +00:00
w . ParseRegex . Regexp , err = regexp . Compile ( r . Values [ "parse" ] )
if err != nil {
2021-05-20 22:25:58 +00:00
return w , err
2021-05-19 15:46:23 +00:00
}
2021-09-16 22:20:25 +00:00
2021-05-19 15:46:23 +00:00
w . URL . URL , err = url . Parse ( r . Values [ "url" ] )
if err != nil {
2021-05-20 22:25:58 +00:00
return w , err
2021-05-19 15:46:23 +00:00
}
2021-09-16 22:20:25 +00:00
2021-05-19 15:46:23 +00:00
w . ReturnField = r . Values [ "returnField" ]
2021-09-16 22:20:25 +00:00
2021-05-20 22:25:58 +00:00
return w , nil
}
func ( p * RestPlugin ) handleWire ( r bot . Request ) bool {
var w wire
var msg string
var err error
w , err = p . mkWire ( r )
2021-09-16 22:20:25 +00:00
if err != nil {
msg = err . Error ( )
goto SEND
}
2021-05-19 15:46:23 +00:00
err = w . Save ( p . db )
if err != nil {
msg = err . Error ( )
goto SEND
}
2021-05-20 22:25:58 +00:00
p . b . RegisterRegex ( p , bot . Message , w . ParseRegex . Regexp , p . mkHandler ( w ) )
2021-05-19 15:46:23 +00:00
msg = fmt . Sprintf ( "Saved %s" , w )
SEND :
p . b . Send ( r . Conn , bot . Message , r . Msg . Channel , msg )
return true
}
2021-05-20 22:25:58 +00:00
// handleTestWire builds a wire from the command without saving it. It then
// runs that wire's handler against the command's sample text, as if the text
// had been posted to the channel.
func (p *RestPlugin) handleTestWire(r bot.Request) bool {
	text := r.Values["text"]
	w, err := p.mkWire(r)
	// Report through handleErr like every other error path, instead of
	// sending the raw error value.
	if p.handleErr(err, r) {
		return true
	}
	// Without a match, the handler would run with no captured values, which
	// hides the real problem (a pattern that does not fit the sample).
	if !w.ParseRegex.MatchString(text) {
		p.b.Send(r.Conn, bot.Message, r.Msg.Channel,
			fmt.Sprintf("%q does not match `%s`", text, w.ParseRegex))
		return true
	}
	h := p.mkHandler(w)
	r.Values = bot.ParseValues(w.ParseRegex.Regexp, text)
	return h(r)
}
func ( p * RestPlugin ) mkHandler ( w wire ) bot . ResponseHandler {
2021-05-19 15:46:23 +00:00
return func ( r bot . Request ) bool {
if r . Msg . User . Name == p . b . Config ( ) . GetString ( "nick" , "" ) {
return false
}
values := bot . RegexValues { }
2021-05-20 13:59:28 +00:00
for _ , s := range p . b . Config ( ) . GetAllSecrets ( ) {
values [ s . Key ] = s . Value
}
log . Debug ( ) . Interface ( "values" , values ) . Msgf ( "secrets" )
2021-05-19 15:46:23 +00:00
for k := range r . Values {
values [ k ] = url . QueryEscape ( r . Values [ k ] )
}
2021-05-20 13:59:28 +00:00
log . Debug ( ) . Interface ( "values" , values ) . Msgf ( "r.Values" )
2021-05-19 15:46:23 +00:00
urlStr := w . URL . String ( )
parse , err := template . New ( urlStr ) . Parse ( urlStr )
if p . handleErr ( err , r ) {
return true
}
buf := bytes . Buffer { }
err = parse . Execute ( & buf , values )
if p . handleErr ( err , r ) {
return true
}
newURL , err := url . Parse ( buf . String ( ) )
2021-09-16 22:20:25 +00:00
if p . handleErr ( err , r ) {
return true
}
bodyTpl := w . Body
parse , err = template . New ( bodyTpl ) . Parse ( bodyTpl )
if p . handleErr ( err , r ) {
return true
}
bodyBuf := bytes . Buffer { }
err = parse . Execute ( & bodyBuf , values )
if p . handleErr ( err , r ) {
return true
}
// Note, we will expect this is formatted good enough to send over the wire (?)
2021-05-20 13:59:28 +00:00
log . Debug ( ) .
Interface ( "values" , values ) .
2021-09-16 22:20:25 +00:00
Str ( "URL" , newURL . String ( ) ) .
Str ( "body" , bodyBuf . String ( ) ) .
2021-05-20 13:59:28 +00:00
Msg ( "Querying URL with values" )
2021-05-19 15:46:23 +00:00
if p . handleErr ( err , r ) {
return true
}
2021-09-16 22:20:25 +00:00
client := & http . Client { }
method := "GET"
if bodyBuf . Len ( ) > 0 {
method = "POST"
}
req , err := http . NewRequest ( method , newURL . String ( ) , bytes . NewReader ( bodyBuf . Bytes ( ) ) )
if p . handleErr ( err , r ) {
return true
}
key := fmt . Sprintf ( "rest.%x.origin" , sha512 . Sum512 ( [ ] byte ( urlStr ) ) )
if origin := p . b . Config ( ) . Get ( key , "" ) ; origin != "" {
req . Header . Add ( "Origin" , origin )
}
resp , err := client . Do ( req )
2021-05-19 15:46:23 +00:00
if p . handleErr ( err , r ) {
return true
}
if resp . StatusCode < 200 || resp . StatusCode >= 300 {
2021-05-20 22:25:58 +00:00
p . b . Send ( r . Conn , bot . Message , r . Msg . Channel , fmt . Sprintf ( "Got a status %d: %s" ,
resp . StatusCode , resp . Status ) )
2021-05-19 15:46:23 +00:00
}
body , err := ioutil . ReadAll ( resp . Body )
if p . handleErr ( err , r ) {
return true
}
2021-09-16 22:20:25 +00:00
log . Debug ( ) . Str ( "body" , string ( body ) ) . Msg ( "got a body back" )
2022-03-22 01:32:44 +00:00
var returnValues any
2021-05-19 15:46:23 +00:00
json . Unmarshal ( body , & returnValues )
2021-05-19 17:58:28 +00:00
2021-05-19 15:46:23 +00:00
var msg string
2021-09-03 13:43:37 +00:00
if pp , ok := postProcessors [ w . ReturnField ] ; ok {
msg = pp ( returnValues )
} else {
query , err := gojq . Parse ( w . ReturnField )
if err != nil {
msg := fmt . Sprintf ( "Wire handler did not find return value: %s => `%s`" , w . URL , w . ReturnField )
p . b . Send ( r . Conn , bot . Message , r . Msg . Channel , msg )
return true
2021-05-19 17:58:28 +00:00
}
2021-09-03 13:43:37 +00:00
iter := query . Run ( returnValues ) // or query.RunWithContext
for {
v , ok := iter . Next ( )
if ! ok {
break
}
if err , ok := v . ( error ) ; ok {
return p . handleErr ( err , r )
}
msg += fmt . Sprintf ( "%s\n" , v )
2021-05-19 17:58:28 +00:00
}
2021-05-19 15:46:23 +00:00
}
2021-05-19 17:58:28 +00:00
msg = strings . TrimSpace ( msg )
2021-09-16 22:20:25 +00:00
if msg == "" {
p . b . Send ( r . Conn , bot . Message , r . Msg . Channel , "we got an empty body after all we've done" )
return true
}
2021-05-19 15:46:23 +00:00
p . b . Send ( r . Conn , bot . Message , r . Msg . Channel , msg )
return true
}
}
func ( p * RestPlugin ) handleErr ( err error , r bot . Request ) bool {
if err != nil {
2021-09-03 13:43:37 +00:00
p . b . Send ( r . Conn , bot . Message , r . Msg . Channel , fmt . Sprintf ( "Error: %s" , err ) )
2021-05-19 15:46:23 +00:00
return true
}
return false
}