first commit

This commit is contained in:
Mateusz Gruszczyński
2026-05-02 20:45:07 +02:00
commit 7922e2ad93
9 changed files with 871 additions and 0 deletions

View File

@@ -0,0 +1,466 @@
package main
import (
"bufio"
"bytes"
"crypto/subtle"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"strconv"
"strings"
"time"
)
type Config struct {
ListenAddr string
Token string
TargetNetwork string
TargetAddress string
TargetURI string
AllowNet string
ReadTimeout time.Duration
WriteTimeout time.Duration
DialTimeout time.Duration
MaxHeaderBytes int
MaxContentBytes int
}
type AllowRule struct {
any bool
ip net.IP
net *net.IPNet
}
func main() {
cfg := loadConfig()
if err := cfg.validate(); err != nil {
log.Fatalf("config error: %v", err)
}
allowRule, err := parseAllowRules(cfg.AllowNet)
if err != nil {
log.Fatalf("ALLOW_NET error: %v", err)
}
ln, err := net.Listen("tcp", cfg.ListenAddr)
if err != nil {
log.Fatalf("listen error on %s: %v", cfg.ListenAddr, err)
}
defer ln.Close()
log.Printf("rtorrent-scgi-proxy listening=%s target=%s:%s target_uri=%s allow=%s",
cfg.ListenAddr, cfg.TargetNetwork, cfg.TargetAddress, cfg.TargetURI, cfg.AllowNet)
for {
conn, err := ln.Accept()
if err != nil {
log.Printf("accept error: %v", err)
continue
}
go handleConn(conn, cfg, allowRule)
}
}
func loadConfig() Config {
return Config{
ListenAddr: getenv("LISTEN_ADDR", "127.0.0.1:5050"),
Token: os.Getenv("TOKEN"),
TargetNetwork: getenv("TARGET_NETWORK", "tcp"),
TargetAddress: getenv("TARGET_ADDRESS", "127.0.0.1:5000"),
TargetURI: getenv("TARGET_URI", "/RPC2"),
AllowNet: getenv("ALLOW_NET", "127.0.0.1"),
ReadTimeout: durationEnv("READ_TIMEOUT", 15*time.Second),
WriteTimeout: durationEnv("WRITE_TIMEOUT", 30*time.Second),
DialTimeout: durationEnv("DIAL_TIMEOUT", 5*time.Second),
MaxHeaderBytes: intEnv("MAX_HEADER_BYTES", 64*1024),
MaxContentBytes: intEnv("MAX_CONTENT_BYTES", 10*1024*1024),
}
}
func (c Config) validate() error {
if c.Token == "" {
return errors.New("TOKEN is required")
}
if strings.Contains(c.Token, "/") || strings.ContainsAny(c.Token, "\x00\r\n") {
return errors.New("TOKEN must not contain slash, NUL or newlines")
}
if c.TargetNetwork != "tcp" && c.TargetNetwork != "unix" {
return errors.New("TARGET_NETWORK must be tcp or unix")
}
if c.TargetAddress == "" {
return errors.New("TARGET_ADDRESS is required")
}
if !strings.HasPrefix(c.TargetURI, "/") {
return errors.New("TARGET_URI must start with /")
}
if c.MaxHeaderBytes < 1024 || c.MaxHeaderBytes > 1024*1024 {
return errors.New("MAX_HEADER_BYTES must be between 1024 and 1048576")
}
if c.MaxContentBytes < 0 || c.MaxContentBytes > 128*1024*1024 {
return errors.New("MAX_CONTENT_BYTES must be between 0 and 134217728")
}
return nil
}
func handleConn(client net.Conn, cfg Config, allow AllowRules) {
defer client.Close()
remoteIP, _, err := net.SplitHostPort(client.RemoteAddr().String())
if err != nil || !allow.Allows(net.ParseIP(remoteIP)) {
writeSimpleResponse(client, "403 Forbidden", "source ip not allowed\n")
log.Printf("blocked remote=%s", client.RemoteAddr())
return
}
_ = client.SetReadDeadline(time.Now().Add(cfg.ReadTimeout))
br := bufio.NewReader(client)
headersRaw, err := readNetstring(br, cfg.MaxHeaderBytes)
if err != nil {
writeSimpleResponse(client, "400 Bad Request", "invalid scgi netstring\n")
log.Printf("netstring error remote=%s err=%v", client.RemoteAddr(), err)
return
}
headers, err := parseSCGIHeaders(headersRaw)
if err != nil {
writeSimpleResponse(client, "400 Bad Request", "invalid scgi headers\n")
log.Printf("header error remote=%s err=%v", client.RemoteAddr(), err)
return
}
cl, err := parseContentLength(headers["CONTENT_LENGTH"], cfg.MaxContentBytes)
if err != nil {
writeSimpleResponse(client, "400 Bad Request", "invalid content length\n")
return
}
body := make([]byte, cl)
if _, err := io.ReadFull(br, body); err != nil {
writeSimpleResponse(client, "400 Bad Request", "could not read body\n")
return
}
token, err := extractToken(headers["REQUEST_URI"])
if err != nil || !constantTimeEqual(token, cfg.Token) {
writeSimpleResponse(client, "403 Forbidden", "invalid token\n")
log.Printf("invalid token remote=%s", client.RemoteAddr())
return
}
rewritten := cloneMap(headers)
rewritten["REQUEST_URI"] = cfg.TargetURI
rewritten["DOCUMENT_URI"] = cfg.TargetURI
rewritten["SCRIPT_NAME"] = cfg.TargetURI
rewritten["PATH_INFO"] = ""
rewritten["QUERY_STRING"] = ""
outReq, err := buildSCGIRequest(rewritten, body)
if err != nil {
writeSimpleResponse(client, "500 Internal Server Error", "could not build upstream request\n")
log.Printf("build request error: %v", err)
return
}
upstream, err := net.DialTimeout(cfg.TargetNetwork, cfg.TargetAddress, cfg.DialTimeout)
if err != nil {
writeSimpleResponse(client, "502 Bad Gateway", "upstream connect failed\n")
log.Printf("upstream dial error target=%s:%s err=%v", cfg.TargetNetwork, cfg.TargetAddress, err)
return
}
defer upstream.Close()
_ = upstream.SetDeadline(time.Now().Add(cfg.WriteTimeout))
if _, err := upstream.Write(outReq); err != nil {
writeSimpleResponse(client, "502 Bad Gateway", "upstream write failed\n")
log.Printf("upstream write error: %v", err)
return
}
_ = client.SetWriteDeadline(time.Now().Add(cfg.WriteTimeout))
if _, err := io.Copy(client, upstream); err != nil {
log.Printf("copy error remote=%s err=%v", client.RemoteAddr(), err)
return
}
}
func readNetstring(r *bufio.Reader, maxHeaderBytes int) ([]byte, error) {
var lenBuf bytes.Buffer
for {
b, err := r.ReadByte()
if err != nil {
return nil, err
}
if b == ':' {
break
}
if b < '0' || b > '9' {
return nil, fmt.Errorf("invalid netstring length byte: %q", b)
}
if lenBuf.Len() > 10 {
return nil, errors.New("netstring length too long")
}
lenBuf.WriteByte(b)
}
if lenBuf.Len() == 0 {
return nil, errors.New("empty netstring length")
}
n, err := strconv.Atoi(lenBuf.String())
if err != nil {
return nil, err
}
if n < 0 || n > maxHeaderBytes {
return nil, errors.New("netstring payload too large")
}
payload := make([]byte, n)
if _, err := io.ReadFull(r, payload); err != nil {
return nil, err
}
trailer, err := r.ReadByte()
if err != nil {
return nil, err
}
if trailer != ',' {
return nil, errors.New("missing netstring trailer comma")
}
return payload, nil
}
func parseSCGIHeaders(payload []byte) (map[string]string, error) {
if len(payload) == 0 {
return nil, errors.New("empty header payload")
}
parts := bytes.Split(payload, []byte{0})
if len(parts) < 3 {
return nil, errors.New("not enough header parts")
}
if len(parts[len(parts)-1]) != 0 {
return nil, errors.New("headers must end with NUL")
}
headers := make(map[string]string)
for i := 0; i < len(parts)-1; i += 2 {
if i+1 >= len(parts)-1 {
return nil, errors.New("odd number of header items")
}
k := string(parts[i])
v := string(parts[i+1])
if k == "" {
return nil, errors.New("empty header name")
}
if strings.ContainsAny(k, "\r\n") {
return nil, fmt.Errorf("invalid header name %q", k)
}
if _, exists := headers[k]; exists {
return nil, fmt.Errorf("duplicate header %q", k)
}
headers[k] = v
}
if headers["CONTENT_LENGTH"] == "" {
return nil, errors.New("missing CONTENT_LENGTH")
}
if headers["SCGI"] != "1" {
return nil, errors.New("missing or invalid SCGI")
}
return headers, nil
}
func parseContentLength(s string, max int) (int, error) {
cl, err := strconv.Atoi(s)
if err != nil || cl < 0 || cl > max {
return 0, errors.New("invalid CONTENT_LENGTH")
}
return cl, nil
}
func buildSCGIRequest(headers map[string]string, body []byte) ([]byte, error) {
h := cloneMap(headers)
h["CONTENT_LENGTH"] = strconv.Itoa(len(body))
h["SCGI"] = "1"
keys := make([]string, 0, len(h))
for k := range h {
if k == "CONTENT_LENGTH" || k == "SCGI" {
continue
}
keys = append(keys, k)
}
sortStrings(keys)
var hb bytes.Buffer
writePair(&hb, "CONTENT_LENGTH", h["CONTENT_LENGTH"])
writePair(&hb, "SCGI", "1")
for _, k := range keys {
v := h[k]
if strings.IndexByte(k, 0) >= 0 || strings.IndexByte(v, 0) >= 0 {
return nil, fmt.Errorf("header contains NUL: %q", k)
}
writePair(&hb, k, v)
}
var out bytes.Buffer
out.WriteString(strconv.Itoa(hb.Len()))
out.WriteByte(':')
out.Write(hb.Bytes())
out.WriteByte(',')
out.Write(body)
return out.Bytes(), nil
}
func writePair(b *bytes.Buffer, k, v string) {
b.WriteString(k)
b.WriteByte(0)
b.WriteString(v)
b.WriteByte(0)
}
func extractToken(uri string) (string, error) {
clean := strings.SplitN(uri, "?", 2)[0]
clean = strings.TrimSpace(clean)
const prefix = "/proxy/"
if !strings.HasPrefix(clean, prefix) {
return "", errors.New("uri must start with /proxy/")
}
token := strings.TrimPrefix(clean, prefix)
if token == "" || strings.Contains(token, "/") || strings.ContainsAny(token, "\x00\r\n") {
return "", errors.New("invalid token path")
}
return token, nil
}
type AllowRules []AllowRule
func parseAllowRules(s string) (AllowRules, error) {
parts := strings.Split(s, ",")
rules := make(AllowRules, 0, len(parts))
for _, part := range parts {
rule, err := parseAllowRule(part)
if err != nil {
return nil, err
}
rules = append(rules, rule)
}
if len(rules) == 0 {
return nil, errors.New("empty allow rules")
}
return rules, nil
}
func parseAllowRule(s string) (AllowRule, error) {
s = strings.TrimSpace(s)
if s == "" {
return AllowRule{}, errors.New("empty allow rule")
}
if s == "*" || s == "0.0.0.0/0" || s == "::/0" {
return AllowRule{any: true}, nil
}
if strings.Contains(s, "/") {
_, n, err := net.ParseCIDR(s)
if err != nil {
return AllowRule{}, err
}
return AllowRule{net: n}, nil
}
ip := net.ParseIP(s)
if ip == nil {
return AllowRule{}, fmt.Errorf("invalid IP: %s", s)
}
return AllowRule{ip: ip}, nil
}
func (rs AllowRules) Allows(ip net.IP) bool {
for _, r := range rs {
if r.Allows(ip) {
return true
}
}
return false
}
func (r AllowRule) Allows(ip net.IP) bool {
if ip == nil {
return false
}
if r.any {
return true
}
if r.ip != nil {
return r.ip.Equal(ip)
}
if r.net != nil {
return r.net.Contains(ip)
}
return false
}
func constantTimeEqual(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
func cloneMap(in map[string]string) map[string]string {
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func writeSimpleResponse(w net.Conn, status, body string) {
resp := fmt.Sprintf("Status: %s\r\nContent-Type: text/plain\r\nContent-Length: %d\r\n\r\n%s", status, len(body), body)
_, _ = io.WriteString(w, resp)
}
func getenv(key, def string) string {
val := strings.TrimSpace(os.Getenv(key))
if val == "" {
return def
}
return val
}
func durationEnv(key string, def time.Duration) time.Duration {
val := strings.TrimSpace(os.Getenv(key))
if val == "" {
return def
}
d, err := time.ParseDuration(val)
if err != nil {
log.Fatalf("invalid %s=%q: %v", key, val, err)
}
return d
}
func intEnv(key string, def int) int {
val := strings.TrimSpace(os.Getenv(key))
if val == "" {
return def
}
i, err := strconv.Atoi(val)
if err != nil {
log.Fatalf("invalid %s=%q: %v", key, val, err)
}
return i
}
func sortStrings(s []string) {
for i := 1; i < len(s); i++ {
v := s[i]
j := i - 1
for j >= 0 && s[j] > v {
s[j+1] = s[j]
j--
}
s[j+1] = v
}
}

View File

@@ -0,0 +1,177 @@
package main
import (
"bufio"
"io"
"net"
"strings"
"testing"
"time"
)
func TestParseSCGIHeaders(t *testing.T) {
raw := []byte("CONTENT_LENGTH\x000\x00SCGI\x001\x00REQUEST_URI\x00/proxy/token\x00")
h, err := parseSCGIHeaders(raw)
if err != nil {
t.Fatalf("parseSCGIHeaders returned error: %v", err)
}
if h["CONTENT_LENGTH"] != "0" || h["SCGI"] != "1" || h["REQUEST_URI"] != "/proxy/token" {
t.Fatalf("unexpected headers: %#v", h)
}
}
func TestBuildSCGIRequestRoundTrip(t *testing.T) {
body := []byte("<methodCall/>")
req, err := buildSCGIRequest(map[string]string{
"CONTENT_LENGTH": "999",
"SCGI": "1",
"REQUEST_URI": "/RPC2",
}, body)
if err != nil {
t.Fatalf("buildSCGIRequest returned error: %v", err)
}
br := bufio.NewReader(strings.NewReader(string(req)))
raw, err := readNetstring(br, 4096)
if err != nil {
t.Fatalf("readNetstring returned error: %v", err)
}
h, err := parseSCGIHeaders(raw)
if err != nil {
t.Fatalf("parseSCGIHeaders returned error: %v", err)
}
if h["CONTENT_LENGTH"] != "13" || h["REQUEST_URI"] != "/RPC2" {
t.Fatalf("unexpected rewritten headers: %#v", h)
}
gotBody, err := io.ReadAll(br)
if err != nil {
t.Fatalf("ReadAll returned error: %v", err)
}
if string(gotBody) != string(body) {
t.Fatalf("unexpected body: %q", gotBody)
}
}
func TestAllowRules(t *testing.T) {
rules, err := parseAllowRules("10.0.0.0/8, 192.168.1.10")
if err != nil {
t.Fatalf("parseAllowRules returned error: %v", err)
}
if !rules.Allows(net.ParseIP("10.2.3.4")) {
t.Fatal("CIDR rule should allow 10.2.3.4")
}
if !rules.Allows(net.ParseIP("192.168.1.10")) {
t.Fatal("single IP rule should allow 192.168.1.10")
}
if rules.Allows(net.ParseIP("172.16.0.1")) {
t.Fatal("rules should block 172.16.0.1")
}
}
func TestEndToEndProxy(t *testing.T) {
upstream, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("upstream listen: %v", err)
}
defer upstream.Close()
upstreamDone := make(chan error, 1)
go func() {
conn, err := upstream.Accept()
if err != nil {
upstreamDone <- err
return
}
defer conn.Close()
br := bufio.NewReader(conn)
raw, err := readNetstring(br, 4096)
if err != nil {
upstreamDone <- err
return
}
h, err := parseSCGIHeaders(raw)
if err != nil {
upstreamDone <- err
return
}
if h["REQUEST_URI"] != "/RPC2" {
upstreamDone <- unexpectedErr("REQUEST_URI", h["REQUEST_URI"])
return
}
cl, err := parseContentLength(h["CONTENT_LENGTH"], 1024)
if err != nil {
upstreamDone <- err
return
}
if _, err := io.CopyN(io.Discard, br, int64(cl)); err != nil {
upstreamDone <- err
return
}
_, err = io.WriteString(conn, "Status: 200 OK\r\nContent-Type: text/xml\r\nContent-Length: 2\r\n\r\nok")
upstreamDone <- err
}()
proxy, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxy.Close()
cfg := Config{
Token: "secret",
TargetNetwork: "tcp",
TargetAddress: upstream.Addr().String(),
TargetURI: "/RPC2",
ReadTimeout: time.Second,
WriteTimeout: time.Second,
DialTimeout: time.Second,
MaxHeaderBytes: 4096,
MaxContentBytes: 1024,
}
rules, err := parseAllowRules("127.0.0.1")
if err != nil {
t.Fatalf("parseAllowRules: %v", err)
}
go func() {
conn, err := proxy.Accept()
if err == nil {
handleConn(conn, cfg, rules)
}
}()
client, err := net.Dial("tcp", proxy.Addr().String())
if err != nil {
t.Fatalf("client dial: %v", err)
}
defer client.Close()
req, err := buildSCGIRequest(map[string]string{
"CONTENT_LENGTH": "0",
"SCGI": "1",
"REQUEST_URI": "/proxy/secret",
}, []byte("ping"))
if err != nil {
t.Fatalf("build request: %v", err)
}
if _, err := client.Write(req); err != nil {
t.Fatalf("client write: %v", err)
}
resp, err := io.ReadAll(client)
if err != nil {
t.Fatalf("client read: %v", err)
}
if !strings.Contains(string(resp), "ok") {
t.Fatalf("unexpected response: %q", resp)
}
if err := <-upstreamDone; err != nil {
t.Fatalf("upstream error: %v", err)
}
}
type unexpectedError string
func (e unexpectedError) Error() string { return string(e) }
func unexpectedErr(field, got string) error {
return unexpectedError(field + "=" + got)
}