|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "crypto/tls" |
| 6 | + "io/ioutil" |
| 7 | + "log" |
| 8 | + "net" |
| 9 | + "net/http" |
| 10 | + "os" |
| 11 | + "strconv" |
| 12 | + "sync" |
| 13 | + "time" |
| 14 | + |
| 15 | + "gopkg.in/yaml.v2" |
| 16 | + |
| 17 | + "github.com/prometheus/client_golang/prometheus" |
| 18 | + "github.com/prometheus/client_golang/prometheus/promhttp" |
| 19 | +) |
| 20 | + |
| 21 | +// BackendConfig represents the configuration for each backend |
| 22 | +type BackendConfig struct { |
| 23 | + Backend string `yaml:"backend"` |
| 24 | + Retries int `yaml:"retries"` |
| 25 | + Delay float32 `yaml:"delay` |
| 26 | + Timeout float32 `yaml:"timeout"` // Timeout in seconds |
| 27 | +} |
| 28 | + |
| 29 | +// Config represents the YAML structure mapping front-facing hosts to their backends |
| 30 | +type Config map[string][]BackendConfig |
| 31 | + |
| 32 | +// ProxyServer handles the proxying logic |
| 33 | +type ProxyServer struct { |
| 34 | + configPath string |
| 35 | + config sync.Map // sync.Map for thread-safe access to the configuration |
| 36 | + queue chan *http.Request // Capped channel acting as the request queue |
| 37 | + workerCount int |
| 38 | + |
| 39 | + // Prometheus metrics |
| 40 | + totalRequests prometheus.Counter |
| 41 | + totalForwarded prometheus.Counter |
| 42 | + totalRetries prometheus.Counter |
| 43 | + totalFailed prometheus.Counter |
| 44 | + totalDropped prometheus.Counter |
| 45 | + totalFailedBodyRead prometheus.Counter |
| 46 | + usedQueueLength prometheus.Gauge |
| 47 | +} |
| 48 | + |
| 49 | +// NewProxyServer initializes a new ProxyServer with Prometheus metrics and worker goroutines |
| 50 | +func NewProxyServer(configPath string, queueSize, workerCount int) *ProxyServer { |
| 51 | + ps := &ProxyServer{ |
| 52 | + configPath: configPath, |
| 53 | + queue: make(chan *http.Request, queueSize), // Capped channel |
| 54 | + workerCount: workerCount, |
| 55 | + // Initialize Prometheus metrics |
| 56 | + totalRequests: prometheus.NewCounter(prometheus.CounterOpts{ |
| 57 | + Name: "proxy_requests_total", |
| 58 | + Help: "Total number of incoming requests", |
| 59 | + }), |
| 60 | + totalForwarded: prometheus.NewCounter(prometheus.CounterOpts{ |
| 61 | + Name: "proxy_forwarded_total", |
| 62 | + Help: "Total number of successfully forwarded requests", |
| 63 | + }), |
| 64 | + totalRetries: prometheus.NewCounter(prometheus.CounterOpts{ |
| 65 | + Name: "proxy_retries_total", |
| 66 | + Help: "Total number of retries", |
| 67 | + }), |
| 68 | + totalFailed: prometheus.NewCounter(prometheus.CounterOpts{ |
| 69 | + Name: "proxy_failed_total", |
| 70 | + Help: "Total number of failed requests", |
| 71 | + }), |
| 72 | + totalDropped: prometheus.NewCounter(prometheus.CounterOpts{ |
| 73 | + Name: "proxy_dropped_total", |
| 74 | + Help: "Total number of dropped requests", |
| 75 | + }), |
| 76 | + totalFailedBodyRead: prometheus.NewCounter(prometheus.CounterOpts{ |
| 77 | + Name: "proxy_failed_body_read_total", |
| 78 | + Help: "Total number of requests with failed body reads", |
| 79 | + }), |
| 80 | + usedQueueLength: prometheus.NewGauge(prometheus.GaugeOpts{ |
| 81 | + Name: "proxy_queue_length", |
| 82 | + Help: "Current length of the request queue", |
| 83 | + }), |
| 84 | + } |
| 85 | + |
| 86 | + // Register Prometheus metrics |
| 87 | + prometheus.MustRegister(ps.totalRequests, ps.totalForwarded, ps.totalRetries, ps.totalFailed, ps.totalDropped, ps.totalFailedBodyRead, ps.usedQueueLength) |
| 88 | + |
| 89 | + ps.loadConfig() // Load config initially |
| 90 | + go ps.reloadConfigPeriodically() // Start goroutine to reload config every 30 seconds |
| 91 | + go ps.updateQueueLengthPeriodically() // Start goroutine to update queue length metrics |
| 92 | + |
| 93 | + // Start worker goroutines |
| 94 | + for i := 0; i < ps.workerCount; i++ { |
| 95 | + go ps.worker(i) |
| 96 | + } |
| 97 | + |
| 98 | + return ps |
| 99 | +} |
| 100 | + |
| 101 | +// loadConfig loads the configuration from the YAML file |
| 102 | +func (p *ProxyServer) loadConfig() error { |
| 103 | + data, err := ioutil.ReadFile(p.configPath) |
| 104 | + if err != nil { |
| 105 | + log.Printf("Error reading config file: %v\n", err) |
| 106 | + return err |
| 107 | + } |
| 108 | + |
| 109 | + var newConfig Config |
| 110 | + if err := yaml.Unmarshal(data, &newConfig); err != nil { |
| 111 | + log.Printf("Error parsing config file: %v\n", err) |
| 112 | + return err |
| 113 | + } |
| 114 | + |
| 115 | + // Update sync.Map with new config |
| 116 | + for host, backends := range newConfig { |
| 117 | + p.config.Store(host, backends) |
| 118 | + } |
| 119 | + |
| 120 | + // Clear old hosts in sync.Map |
| 121 | + p.config.Range(func(key, value interface{}) bool { |
| 122 | + found := false |
| 123 | + for host, _ := range newConfig { |
| 124 | + if host == key { |
| 125 | + found = true |
| 126 | + } |
| 127 | + } |
| 128 | + if !found { |
| 129 | + log.Printf("Deleting old host %v\n", key) |
| 130 | + p.config.Delete(key) |
| 131 | + } |
| 132 | + return true |
| 133 | + }) |
| 134 | + /* Printf("New map:") |
| 135 | + p.config.Range(func(key, value interface{}) bool { |
| 136 | + log.Printf("%v\n", key) |
| 137 | + log.Printf("%v\n", value) |
| 138 | + return true |
| 139 | + }) |
| 140 | + */ |
| 141 | + log.Println("Configuration reloaded successfully") |
| 142 | + return nil |
| 143 | +} |
| 144 | + |
| 145 | +// reloadConfigPeriodically reloads the config from the file every 30 seconds |
| 146 | +func (p *ProxyServer) reloadConfigPeriodically() { |
| 147 | + for { |
| 148 | + time.Sleep(30 * time.Second) |
| 149 | + if err := p.loadConfig(); err != nil { |
| 150 | + log.Println("Failed to reload config:", err) |
| 151 | + } |
| 152 | + } |
| 153 | +} |
| 154 | + |
| 155 | +// updateQueueLengthPeriodically updates the queue length every 10 seconds |
| 156 | +func (p *ProxyServer) updateQueueLengthPeriodically() { |
| 157 | + for { |
| 158 | + p.usedQueueLength.Set(float64(len(p.queue))) // Update queue length gauge |
| 159 | + time.Sleep(10 * time.Second) |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +// worker processes requests from the queue |
| 164 | +func (p *ProxyServer) worker(id int) { |
| 165 | + log.Printf("Worker %d started\n", id) |
| 166 | + for req := range p.queue { |
| 167 | + p.proxyRequest(req) // Process each request from the queue |
| 168 | + } |
| 169 | +} |
| 170 | + |
| 171 | +// getBackendsForHost returns the backend configurations for a given host |
| 172 | +func (p *ProxyServer) getBackendsForHost(host string) ([]BackendConfig, bool) { |
| 173 | + if backends, found := p.config.Load(host); found { |
| 174 | + return backends.([]BackendConfig), true |
| 175 | + } |
| 176 | + return nil, false |
| 177 | +} |
| 178 | + |
| 179 | +// proxyRequest forwards the request to the backend with retries based on the configuration |
| 180 | +func (p *ProxyServer) proxyRequest(r *http.Request) { |
| 181 | + // Increment total requests counter |
| 182 | + p.totalRequests.Inc() |
| 183 | + |
| 184 | + host := r.Host // The front-facing host |
| 185 | + backends, found := p.getBackendsForHost(host) |
| 186 | + |
| 187 | + if !found || len(backends) == 0 { |
| 188 | + p.totalFailed.Inc() // Increment failed request counter |
| 189 | + log.Printf("Error host: '%v' not found in config file, droping request\n", host) |
| 190 | + return |
| 191 | + } |
| 192 | + |
| 193 | + var lastErr error |
| 194 | + for _, backend := range backends { |
| 195 | + client := &http.Client{ |
| 196 | + //The timeout includes connection time, any |
| 197 | + // redirects, and reading the response body. The timer remains |
| 198 | + // running after Get, Head, Post, or Do return and will |
| 199 | + // interrupt reading of the Response.Body. |
| 200 | + Timeout: (time.Duration(backend.Timeout) + 1) * time.Second, |
| 201 | + Transport: &http.Transport{ |
| 202 | + MaxIdleConns: 100, |
| 203 | + IdleConnTimeout: 90 * time.Second, |
| 204 | + MaxIdleConnsPerHost: 10, |
| 205 | + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // Skip certificate verification (not recommended in production) |
| 206 | + DialContext: (&net.Dialer{ |
| 207 | + // Timeout is the maximum amount of time a dial will wait for |
| 208 | + // a connect to complete. |
| 209 | + Timeout: time.Duration(backend.Timeout) * time.Second, |
| 210 | + }).DialContext, |
| 211 | + }, |
| 212 | + } |
| 213 | + |
| 214 | + for i := 0; i <= backend.Retries; i++ { |
| 215 | + req, err := http.NewRequest(r.Method, backend.Backend+r.URL.Path, bytes.NewReader([]byte{})) |
| 216 | + if err != nil { |
| 217 | + lastErr = err |
| 218 | + continue |
| 219 | + } |
| 220 | + |
| 221 | + req.Header = r.Header |
| 222 | + resp, err := client.Do(req) |
| 223 | + if err == nil && resp.StatusCode < 400 { |
| 224 | + defer resp.Body.Close() |
| 225 | + // Successfully forwarded request |
| 226 | + p.totalForwarded.Inc() |
| 227 | + return |
| 228 | + } |
| 229 | + |
| 230 | + lastErr = err |
| 231 | + p.totalRetries.Inc() // Increment retries counter |
| 232 | + time.Sleep(time.Duration(backend.Delay) * time.Second) // Small delay before retrying |
| 233 | + } |
| 234 | + } |
| 235 | + |
| 236 | + // If we get here, all backends failed |
| 237 | + p.totalFailed.Inc() // Increment failed requests counter |
| 238 | + log.Printf("All backends failed for host %s: %v\n", host, lastErr) |
| 239 | +} |
| 240 | + |
| 241 | +// handleIncomingRequest queues incoming requests |
| 242 | +func (p *ProxyServer) handleIncomingRequest(w http.ResponseWriter, r *http.Request) { |
| 243 | + // Increment failed body read counter if the body can't be read |
| 244 | + _, err := ioutil.ReadAll(r.Body) |
| 245 | + if err != nil { |
| 246 | + p.totalFailedBodyRead.Inc() |
| 247 | + http.Error(w, "Failed to read body", http.StatusBadRequest) |
| 248 | + return |
| 249 | + } |
| 250 | + |
| 251 | + // Try to add the request to the queue |
| 252 | + select { |
| 253 | + case p.queue <- r: |
| 254 | + //log.Printf(w, "Request queued\n") |
| 255 | + default: |
| 256 | + p.totalDropped.Inc() // Increment dropped requests counter |
| 257 | + http.Error(w, "Queue is full", http.StatusServiceUnavailable) |
| 258 | + } |
| 259 | +} |
| 260 | + |
| 261 | +func getEnv(key string, defaultValue string) string { |
| 262 | + if value, exists := os.LookupEnv(key); exists { |
| 263 | + return value |
| 264 | + } |
| 265 | + return defaultValue |
| 266 | +} |
| 267 | + |
| 268 | +func main() { |
| 269 | + queueSize, _ := strconv.Atoi(getEnv("QUEUE_SIZE", "100")) |
| 270 | + workerCount, _ := strconv.Atoi(getEnv("WORKER_COUNT", "5")) |
| 271 | + listenAddress := getEnv("LISTEN_ADDRESS", ":8080") |
| 272 | + metricsPort := getEnv("METRICS_PORT", ":9091") |
| 273 | + configPath := getEnv("CONFIG_PATH", "/etc/backends.yaml") |
| 274 | + |
| 275 | + // Create a new proxy server with the path to the YAML config |
| 276 | + proxy := NewProxyServer(configPath, queueSize, workerCount) // queue size = 100, worker count = 5 |
| 277 | + |
| 278 | + // HTTP handler for incoming requests |
| 279 | + http.HandleFunc("/", proxy.handleIncomingRequest) |
| 280 | + |
| 281 | + // Start the Prometheus metrics server on a separate port |
| 282 | + go func() { |
| 283 | + metricsMux := http.NewServeMux() |
| 284 | + metricsMux.Handle("/metrics", promhttp.Handler()) |
| 285 | + log.Printf("Prometheus metrics server listening on %s\n", metricsPort) |
| 286 | + log.Fatal(http.ListenAndServe(metricsPort, metricsMux)) |
| 287 | + }() |
| 288 | + |
| 289 | + // Start the proxy server |
| 290 | + log.Printf("Proxy server is listening on %s\n", listenAddress) |
| 291 | + log.Fatal(http.ListenAndServe(listenAddress, nil)) |
| 292 | +} |
0 commit comments