Rate-limiting HTTP requests (via http.HandlerFunc middleware)

I'm looking to write a small piece of rate-limiting middleware that:

  1. Allows me to set a sensible rate (say, 10 req/s) per remote IP
  2. Possibly (though it doesn't have to) allows for bursts
  3. Drops (closes?) connections that exceed the rate and returns an HTTP 429

I can then wrap this around authentication routes or other routes that might be vulnerable to brute-force attacks (e.g. password reset URLs using a token that expires). The chances of someone brute-forcing a 16- or 24-byte token are really low, but it doesn't hurt to go that extra step.

I've had a look at https://code.google.com/p/go-wiki/wiki/RateLimiting but am not sure how to reconcile it with http.Request(s). Further, I'm not sure how I'd "track" requests from a given IP over a period of time.

Ideally I'd end up with something like this. Note that I'm behind a reverse proxy (nginx), so I'm checking the REMOTE_ADDR HTTP header rather than using r.RemoteAddr:

// Rate-limiting middleware
func rateLimit(h http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {

        remoteIP := r.Header.Get("REMOTE_ADDR")
        for req := range (what here?) {
            // what here?
            // w.WriteHeader(429) and close the request if it exceeds the limit
            // else pass to the next handler in the chain
            h.ServeHTTP(w, r)
        }
    }
}

// Example routes
r.HandleFunc("/login", use(loginForm, rateLimit, csrf))
r.HandleFunc("/form", use(editHandler, rateLimit, csrf))

// Middleware wrapper, for context
func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
    for _, m := range middleware {
        h = m(h)
    }

    return h
}

I'd appreciate some guidance here.

The rate limiting example you've linked to is a general one. It uses range because it gets requests over a channel.

It's a different story with HTTP requests, but there's nothing really complicated here. Note that you don't iterate over a channel of requests or anything like that: your HandlerFunc is called separately for every incoming request.

func rateLimit(h http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        remoteIP := r.Header.Get("REMOTE_ADDR")
        // exceededTheLimit is a placeholder for your own counter check
        if exceededTheLimit(remoteIP) {
            // Reply with 429 and return, not passing the request down the chain
            http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
            return
        }
        h.ServeHTTP(w, r)
    }
}

Now, choosing where to store the rate-limit counters is up to you. One solution would be to simply use a global map (don't forget to make access to it safe for concurrent use) that maps IPs to their request counters. However, you would also have to track how long ago the requests were made, so that old counts expire.

Sergio suggested using Redis. Its key-value nature is a perfect fit for simple structures like this and you get expiration for free.

You could store the data in Redis. Here's a very useful command whose documentation even mentions rate limiting as an application: INCR. Redis will also handle cleanup of old data (via expiration of old keys).

Also, with Redis as the rate-limiter storage, you can run multiple frontend processes that share this central storage.

Some would argue that going to an external process on each request is expensive. But a password reset page is not the kind of page that absolutely demands the best performance. Also, if you run Redis on the same machine, latency should be pretty low.

I did something simple and similar this morning; I think it could help in your case.

package main

import (
    "log"
    "net"
    "net/http"
    "sync"
    "time"
)

func main() {
    fs := http.FileServer(http.Dir("./html/"))
    http.Handle("/", fs)
    log.Println("Listening..")
    go clearLastRequestsIPs()
    go clearBlockedIPs()
    err := http.ListenAndServe(":8080", middleware(nil))
    if err != nil {
        log.Fatalln(err)
    }
}

// mu guards both slices below, since handlers run concurrently
var mu sync.Mutex

// Stores the IPs of recent requests (cleared every 5 minutes)
var lastRequestsIPs []string

// Blocked IPs (cleared every 6 hours, so a block lasts up to 6 hours)
var blockedIPs []string

func middleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // net.SplitHostPort also handles IPv6 addresses such as "[::1]:1234"
        ipAddr, _, err := net.SplitHostPort(r.RemoteAddr)
        if err != nil {
            ipAddr = r.RemoteAddr
        }

        mu.Lock()
        if existsBlockedIP(ipAddr) {
            mu.Unlock()
            http.Error(w, "", http.StatusTooManyRequests)
            return
        }
        // how many requests the current IP made in the last 5 minutes
        requestCounter := 0
        for _, ip := range lastRequestsIPs {
            if ip == ipAddr {
                requestCounter++
            }
        }
        if requestCounter >= 1000 {
            blockedIPs = append(blockedIPs, ipAddr)
            mu.Unlock()
            http.Error(w, "", http.StatusTooManyRequests)
            return
        }
        lastRequestsIPs = append(lastRequestsIPs, ipAddr)
        mu.Unlock()

        // Don't cut the chain of middlewares
        if next == nil {
            http.DefaultServeMux.ServeHTTP(w, r)
            return
        }
        next.ServeHTTP(w, r)
    })
}

// existsBlockedIP must be called with mu held
func existsBlockedIP(ipAddr string) bool {
    for _, ip := range blockedIPs {
        if ip == ipAddr {
            return true
        }
    }
    return false
}

// existsLastRequest must be called with mu held
func existsLastRequest(ipAddr string) bool {
    for _, ip := range lastRequestsIPs {
        if ip == ipAddr {
            return true
        }
    }
    return false
}

// Clears lastRequestsIPs every 5 minutes
func clearLastRequestsIPs() {
    for {
        mu.Lock()
        lastRequestsIPs = []string{}
        mu.Unlock()
        time.Sleep(time.Minute * 5)
    }
}

// Clears blockedIPs every 6 hours
func clearBlockedIPs() {
    for {
        mu.Lock()
        blockedIPs = []string{}
        mu.Unlock()
        time.Sleep(time.Hour * 6)
    }
}

It's not precise yet, but it should serve as a simple example of a rate limiter. You can improve it by adding the requested path, the HTTP method, and even authentication as factors in deciding whether the traffic is an attack.

Here's my rate-limit middleware implementation. It works very nicely as a global rate limiter or as a rate limiter for an individual route. I use it extensively in my apps.

Here is what you get with it:

  • no external dependencies
  • testable
  • configurable
  • adds headers so a client can tell how many requests they have left before they are limited, etc.
  • automatically removes expired data

First, the usage:

r := router.New()
stats := stats.New()
r.With(middleware.RateLimit(1, time.Minute * 1, stats)).Post("/contact", c.Contact)

The middleware above allows one request per minute for POST requests to /contact.

Here is the middleware:

package middleware

import (
    "net"
    "net/http"
    "strconv"
    "sync"
    "time"
)

// Stats is an interface to an underlying hash table/map data
// structure. Implement it however you'd like, but it must be safe
// for concurrent use, since handlers run in parallel.
type Stats interface {
    // Reset will reset the map.
    Reset()

    // Add adds "count" to the value at the key "identifier"
    // and returns the new total at that key.
    Add(identifier string, count int) int
}

// RateLimit middleware is a generic rate limiter that can be used in any scenario
// because it allows granular rate limiting for each specific route. Or you can
// set the rate limiter on an entire router group. It returns a standard
// func(http.Handler) http.Handler middleware.
func RateLimit(limit int, window time.Duration, stats Stats) func(next http.Handler) http.Handler {
    // windowStart is written by the ticker goroutine and read by handlers,
    // so it is guarded by mu.
    var mu sync.Mutex
    windowStart := time.Now()

    // Clear the rate limit stats after each window.
    ticker := time.NewTicker(window)
    go func() {
        for range ticker.C {
            mu.Lock()
            windowStart = time.Now()
            mu.Unlock()
            stats.Reset()
        }
    }()

    return func(next http.Handler) http.Handler {
        h := func(w http.ResponseWriter, r *http.Request) {
            // value includes this request.
            value := stats.Add(identifyRequest(r), 1)

            remaining := limit - value
            if remaining < 0 {
                remaining = 0
            }

            mu.Lock()
            elapsed := time.Since(windowStart)
            mu.Unlock()
            reset := int((window - elapsed).Seconds()) + 1

            w.Header().Set("X-Rate-Limit-Limit", strconv.Itoa(limit))
            w.Header().Set("X-Rate-Limit-Remaining", strconv.Itoa(remaining))
            w.Header().Set("X-Rate-Limit-Reset", strconv.Itoa(reset))

            // Reject only once the count goes past the limit, so that
            // limit=1 still lets the first request through.
            if value > limit {
                http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
                // Do something else (log, alert, ...) if needed.
                return
            }
            next.ServeHTTP(w, r)
        }

        return http.HandlerFunc(h)
    }
}

// identifyRequest gets an identifier for the request. This version uses the
// client IP; adapt it (e.g. to read a proxy header) to your setup.
func identifyRequest(r *http.Request) string {
    ip, _, err := net.SplitHostPort(r.RemoteAddr)
    if err != nil {
        return r.RemoteAddr
    }
    return ip
}
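The Stats interface is left open. A minimal concurrency-safe, map-backed implementation could look like the sketch below; `memStats` and `newMemStats` are hypothetical names, and the interface is copied from the middleware package only so this compiles on its own:

```go
package main

import "sync"

// Stats mirrors the interface from the middleware package.
type Stats interface {
    Reset()
    Add(identifier string, count int) int
}

// memStats is a hypothetical mutex-guarded implementation of Stats.
type memStats struct {
    mu     sync.Mutex
    counts map[string]int
}

// Compile-time check that memStats satisfies Stats.
var _ Stats = (*memStats)(nil)

func newMemStats() *memStats {
    return &memStats{counts: make(map[string]int)}
}

// Reset drops all counters; RateLimit calls it at the end of each window.
func (s *memStats) Reset() {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.counts = make(map[string]int)
}

// Add increments the counter for identifier by count and returns the new total.
func (s *memStats) Add(identifier string, count int) int {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.counts[identifier] += count
    return s.counts[identifier]
}
```

Because Reset replaces the whole map, expired data disappears at each window boundary without per-key bookkeeping.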