server/server.go

155 lines
3 KiB
Go

package main
import (
"bufio"
"bytes"
"database/sql"
"flag"
"fmt"
"log"
"os"
"strconv"
"strings"
_ "github.com/mattn/go-sqlite3"
"github.com/valyala/fasthttp"
)
// pathsep is the OS path separator as a string, used to build paths under
// the -storage directory.
const pathsep = string(os.PathSeparator)

// Command-line flags.
var (
	bind    = flag.String("bind", "127.0.0.1:4001", "Where to listen for connections")
	proxy   = flag.Bool("proxy", true, "Whether server is behind a reverse proxy")
	verbose = flag.Bool("verbose", false, "Whether to log all URLs")
	dbpath  = flag.String("dbpath", "archive.db", "Path to SQLite3 DB generated by mitmproxy addon")
	storage = flag.String("storage", "storage", "Path to archived responses storage")
)

// removeHeaders holds lower-cased names of archived headers that are dropped
// instead of being replayed to the client.
var removeHeaders = map[string]bool{
	"date":    true,
	"expires": true,
	"server":  true,
}

var (
	// conn is the archive database handle; assigned once by main.
	conn *sql.DB
	// stmt is the prepared (method, url) -> (id, code) lookup; assigned once
	// by main and used by handler.
	stmt *sql.Stmt
)
// main parses the command-line flags and runs the replay server, exiting
// with a logged error if startup or serving fails.
func main() {
	flag.Parse()
	if err := runServer(); err != nil {
		log.Fatalln(err)
	}
}

// runServer opens the archive DB, prepares the lookup statement used by
// handler, and serves on -bind until the listener returns an error.
// Returning instead of exiting lets the deferred Close calls run.
func runServer() error {
	db, err := sql.Open("sqlite3", *dbpath)
	if err != nil {
		return fmt.Errorf("opening database %q: %w", *dbpath, err)
	}
	defer db.Close()
	conn = db

	// Prepare fails here (rather than per request) when the DB lacks the
	// expected "data" table, so a bad -dbpath is reported at startup.
	selStmt, err := db.Prepare("select id, code from data where method = ? and url = ? limit 1")
	if err != nil {
		return fmt.Errorf("preparing lookup statement on %q: %w", *dbpath, err)
	}
	defer selStmt.Close()
	stmt = selStmt

	return fasthttp.ListenAndServe(*bind, handler)
}
// handler replays an archived response for the incoming request.
//
// It rebuilds the absolute URL as "scheme://host:port<request-uri>" and looks
// up (method, url) with the prepared stmt. It then serves the stored status
// code, the headers (minus removeHeaders) and the body from
// <storage>/<id>/headers and <storage>/<id>/body.
//
// Responses:
//   - unknown URL: empty 404
//   - DB or storage failure: 500 via sendError
//
// Both storage files are opened before the response is touched, so a failure
// never leaves replayed headers on the error reply.
func handler(ctx *fasthttp.RequestCtx) {
	// -- rebuild the lookup URL
	uri := ctx.URI()
	var scheme []byte
	if *proxy {
		scheme = ctx.Request.Header.Peek("X-Forwarded-Proto")
	}
	// Not proxied, or the proxy did not send X-Forwarded-Proto: use the
	// request's own scheme. Otherwise the URL would begin with "://" and
	// could never match.
	if len(scheme) == 0 {
		scheme = uri.Scheme()
	}
	host, port := parseHost(uri.Host(), bytes.Equal(scheme, []byte("https")))
	urlToSearch := fmt.Sprintf("%s://%s:%s%s", scheme, host, port, ctx.RequestURI())
	if *verbose {
		log.Println(urlToSearch)
	}

	// -- find in DB and read id+code
	var id, code int
	err := stmt.QueryRow(string(ctx.Method()), urlToSearch).Scan(&id, &code)
	if err == sql.ErrNoRows {
		ctx.Response.SetStatusCode(404)
		return
	} else if err != nil {
		sendError(ctx, "Unable to fetch row: ", err)
		return
	}

	// -- load headers and open body before mutating the response
	dir := *storage + pathsep + strconv.Itoa(id)
	headers, err := readHeaders(dir + pathsep + "headers")
	if err != nil {
		sendError(ctx, "Unable to read headers: ", err)
		return
	}
	body, err := os.Open(dir + pathsep + "body")
	if err != nil {
		sendError(ctx, "Unable to read body: ", err)
		return
	}

	// -- replay status, headers and body
	ctx.Response.SetStatusCode(code)
	for _, h := range headers {
		ctx.Response.Header.Add(h[0], h[1])
	}
	// Size -1 streams the body without a known length. Per the fasthttp docs,
	// the stream is closed after sending because *os.File is an io.Closer.
	ctx.Response.SetBodyStream(body, -1)
}

// readHeaders parses an archived headers file made of "Name: value" lines.
//
// It returns {lower-cased name, value} pairs in file order. The following
// lines are skipped:
//   - lines without a ": " separator (instead of panicking on them);
//   - lines whose name is listed in removeHeaders.
//
// It returns an error if the file cannot be opened or a line cannot be read,
// e.g. a line longer than bufio.Scanner's token limit.
func readHeaders(name string) ([][2]string, error) {
	f, err := os.Open(name)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var headers [][2]string
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		parts := strings.SplitN(sc.Text(), ": ", 2)
		if len(parts) != 2 {
			continue
		}
		key := strings.ToLower(parts[0])
		if removeHeaders[key] {
			continue
		}
		headers = append(headers, [2]string{key, parts[1]})
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	return headers, nil
}
// parseHost splits a Host value into host and port.
//
// A colon separates the port only when it follows the last ']'. This keeps
// bracketed IPv6 literals such as "[::1]" intact. When there is no port, or
// it is empty ("host:"), the default port for the scheme is used: "443" if
// https, else "80".
//
// The returned host (and an explicit port) alias the input slice.
func parseHost(host []byte, https bool) ([]byte, []byte) {
	resHost, port := host, []byte(nil)
	if i := bytes.LastIndexByte(host, ':'); i != -1 && i > bytes.LastIndexByte(host, ']') {
		resHost, port = host[:i], host[i+1:]
	}
	if len(port) == 0 {
		if https {
			port = []byte("443")
		} else {
			port = []byte("80")
		}
	}
	return resHost, port
}
// sendError replies with HTTP 500. The body is msg followed directly by
// err's text, so msg should carry its own trailing separator. err must be
// non-nil.
func sendError(ctx *fasthttp.RequestCtx, msg string, err error) {
	resp := &ctx.Response
	resp.SetBodyString(msg + err.Error())
	resp.SetStatusCode(fasthttp.StatusInternalServerError)
}