Rate Limit Challenge

mattseh

import this
Apr 6, 2009
A ~= A
Imagine you have a web app which is deployed to multiple machines. No state is stored between requests in the web app itself, although databases etc. are of course allowed.

You need to write a function which rate limits access to a given URL within your app, for a specific user. The limits are: max 5 hits per second, max 10 hits per 20 seconds, max 25 hits per 60 seconds, and these should be easy for the user of the function to change. If any of the limits is exceeded, access is denied.

The programming language doesn’t matter; how would you do this? I’m only interested in the function, which returns true or false, not the actual web app. You can use any third party database / library / whatever. No record of the rate limiting needs to be stored long term.
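For reference, here is what the limits mean when there is only one process: a sliding log that keeps the timestamp of each accepted hit per key and refuses a hit if any window is already full. This is a sketch of my own, not one of the entries in this thread. It keeps its state in memory, so on its own it does not meet the multi-machine, no-state requirement. `SlidingLogLimiter` is a hypothetical name, and I've assumed a hit stops counting exactly `period` seconds after it was made.

```python
import collections
import time


class SlidingLogLimiter:
    """In-memory sliding-log limiter (single process only; a sketch)."""

    def __init__(self, limits):
        # limits maps window length in seconds -> max hits in that window,
        # e.g. {1: 5, 20: 10, 60: 25}; changing the dict changes the policy.
        self.limits = dict(limits)
        self.longest = max(self.limits)
        self.hits = collections.defaultdict(collections.deque)  # key -> accepted timestamps

    def allow(self, key, now=None):
        """Return True and record the hit if no window is over its limit."""
        now = time.time() if now is None else now
        log = self.hits[key]
        # Forget hits that have left even the longest window.
        while log and log[0] <= now - self.longest:
            log.popleft()
        for period, limit in self.limits.items():
            # A hit at time s still counts while now - period < s.
            in_window = sum(1 for s in log if s > now - period)
            if in_window >= limit:
                return False
        log.append(now)
        return True
```

With the thread's limits, one request per second for 30 seconds lets 20 through: t=0 to 9, then t=20 to 29 once the earliest hits leave the 20-second window.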

Here’s my attempt, which I’ll reveal once everyone else who wants to has had a shot:

$ md5 rate_limit.py
MD5 (rate_limit.py) = 0220011c8421748197c620bf333a315a

Bonus points for clarity, efficiency, etc.

Have fun :)
 


Personally, I'd just use a central database that each slave client / server connects to and checks to make sure the user isn't over the limit.

Or is this more of P2P type of thing?

EDIT: Maybe just a quick daemon running on a separate port to avoid the HTTP overhead. Each client connects and keeps the connection open for the duration of the session. The client sends the URL for each request it wants to make, and the daemon responds with 1 or 0. The daemon keeps track of it all centrally in one of the many fast hash databases available (e.g. Berkeley DB).
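A rough sketch of such a daemon in Python, assuming a line-based protocol where the client sends `<user> <url>` for each request and reads back `1` (allow) or `0` (deny). The protocol format, the names `check`, `serve_lines`, `RateLimitHandler` and `run`, the port, and keeping the counters in an in-process dict instead of Berkeley DB are all my assumptions.

```python
import collections
import socketserver
import threading
import time

# Assumed policy: window length in seconds -> max hits per (user, url).
LIMITS = {1: 5, 20: 10, 60: 25}
_hits = collections.defaultdict(collections.deque)  # (user, url) -> accepted timestamps
_lock = threading.Lock()  # the server handles each connection in its own thread


def check(user, url, now):
    """Sliding-log check for one (user, url) pair; records the hit if allowed."""
    with _lock:
        log = _hits[(user, url)]
        # Forget hits older than the longest window.
        while log and log[0] <= now - max(LIMITS):
            log.popleft()
        for period, limit in LIMITS.items():
            if sum(1 for s in log if s > now - period) >= limit:
                return False
        log.append(now)
        return True


def serve_lines(rfile, wfile, clock=time.time):
    """Answer each '<user> <url>' line with 1 (allow) or 0 (deny) until the client disconnects."""
    for raw in rfile:
        parts = raw.decode("utf-8", "replace").split()
        # Malformed lines are denied instead of dropping the connection.
        ok = len(parts) == 2 and check(parts[0], parts[1], clock())
        wfile.write(b"1\n" if ok else b"0\n")
        wfile.flush()


class RateLimitHandler(socketserver.StreamRequestHandler):
    # One handler per client connection, kept open for the whole session.
    def handle(self):
        serve_lines(self.rfile, self.wfile)


def run(host="127.0.0.1", port=9999):
    """Start the daemon; port 9999 is an arbitrary choice for this sketch."""
    with socketserver.ThreadingTCPServer((host, port), RateLimitHandler) as server:
        server.serve_forever()
```

Keeping the protocol logic in `serve_lines` means it can be exercised with in-memory byte streams, without opening a socket.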
 
Not P2P. A central database is a valid solution, but is it efficient? I think it depends on the implementation; how would you do it?
 
See my edit above for possible solution.

Too broad of a question though. For example, can you partition / reserve URLs for each slave? If so, that makes things nice and easy. Just make a little check-in / check-out system. Keep a pending list of URLs that need to be processed on the master, which all slaves have access to. Before scraping a URL, a slave asks the master to check out that URL, which ensures it isn't being used by another slave. Then it's up to the slave to adhere to the timing limits, as it'll be the only one hitting that URL.

Then add redundancy: if slave A can't connect to URL XYZ, it checks the URL back in to the pending list so the other slaves can pick it up and try. Once X slaves have failed to connect to the URL, trash it.
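The master's side of that check-out / check-in scheme could look roughly like this, as a single-process sketch. The `UrlPool` class, its method names and `max_failures` (the "X" above) are hypothetical; it counts failures per URL rather than per distinct slave, and a real master would have to expose these calls to the slaves over the network.

```python
import collections


class UrlPool:
    """Master-side bookkeeping: pending URLs, who holds each one, and failure counts."""

    def __init__(self, urls, max_failures=3):
        self.pending = collections.deque(urls)
        self.checked_out = {}  # url -> slave currently holding it
        self.failures = collections.Counter()
        self.max_failures = max_failures
        self.trashed = []

    def check_out(self, slave):
        """Hand the next pending URL to one slave, or None if nothing is pending."""
        if not self.pending:
            return None
        url = self.pending.popleft()
        self.checked_out[url] = slave
        return url

    def done(self, url):
        """The slave finished scraping the URL."""
        self.checked_out.pop(url, None)

    def check_in_failed(self, url):
        """The slave couldn't connect: requeue the URL, or trash it after too many failures."""
        self.checked_out.pop(url, None)
        self.failures[url] += 1
        if self.failures[url] >= self.max_failures:
            self.trashed.append(url)
        else:
            self.pending.append(url)
```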
 
This isn't for scraping, this is for serving a web app. That logic makes sense for scraping for sure though.
 
Keeping it simple and only focusing on the 60-second window, the most naive solution that comes to mind is Redis, using EXPIRE as the TTL clamp on the window.

Let's say requests are coming in with an `access_token` in their header.

- As requests come in, a middleware function plucks `access_token` from the request.
- If 25 <= (count (redis/query "KEYS {access_token}:*")), then respond with a 429 Too Many Requests.
- Else, let key = "{access_token}:{timestamp}", then pipeline (redis/query "SET {key} true") and (redis/query "EXPIRE {key} 60") before passing the request to an app handler.

(Not a challenge entry ^_^)
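To see the mechanics without a Redis server, here is a sketch that models SET / EXPIRE / KEYS with a tiny in-memory class (`FakeTTLStore` is hypothetical, not a Redis client) and takes the clock as an argument. One deliberate change: each key gets a UUID suffix, because with a bare second-resolution timestamp two requests in the same second would SET the same key and be counted once. Against a real Redis, KEYS also walks the whole keyspace, so it gets slower as the database grows.

```python
import uuid


class FakeTTLStore:
    """Hypothetical in-memory stand-in for Redis SET + EXPIRE and a prefix-only KEYS."""

    def __init__(self):
        # key -> time at which the key disappears (expired keys are never purged here)
        self.expires_at = {}

    def set_with_ttl(self, key, ttl, now):
        # Models the pipelined SET key true / EXPIRE key ttl.
        self.expires_at[key] = now + ttl

    def count_prefix(self, prefix, now):
        # Models counting KEYS prefix*, skipping keys whose TTL has run out.
        return sum(1 for k, t in self.expires_at.items() if k.startswith(prefix) and t > now)


def allow_request(store, access_token, now, limit=25, window=60):
    """Return False (i.e. respond 429 Too Many Requests) once the token has `limit` live keys."""
    prefix = "{}:".format(access_token)
    if store.count_prefix(prefix, now) >= limit:
        return False
    # Timestamp plus a unique suffix, so two requests in the same second get two keys.
    key = "{}{}:{}".format(prefix, int(now), uuid.uuid4().hex)
    store.set_with_ttl(key, window, now)
    return True
```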
 

If it's for serving a web app: a shared directory on the master containing a Berkeley DB or whatever hash / DB, with each slave mounting the directory via NFS or whatever, so all slaves run off the same hash. If the load gets too heavy, partition the hashes by the last digit of the user ID, or whatever.
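Partitioning by the last digit of the user ID amounts to `user_id % 10`. A sketch of picking the shard file, assuming integer user IDs; the mount point and file naming are made up for illustration:

```python
SHARD_DIR = "/mnt/ratelimit"  # hypothetical NFS mount point shared by every slave
NUM_SHARDS = 10  # one shard per last decimal digit of the user ID


def shard_path(user_id):
    """Pick the hash/DB file holding this user's counters (assumes integer user IDs)."""
    return "{}/limits_{}.db".format(SHARD_DIR, user_id % NUM_SHARDS)
```

Every slave computes the same path for a given user, so all of that user's hits land in the same file.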

To be honest though, I'd actually contact one of the qualified server admins I know and ask for their input before doing anything. This sounds like something that's better handled on the server side than in the application code, and I'm assuming they would have a better solution than I do.
 
 
Ok, here's how I did it:

Code:
import redis
import time

def rate_limit_check(r, key, limits):
    # limits maps period length in seconds -> max hits allowed in that period
    period_lengths = [_[0] for _ in sorted(limits.items())]
    period_limits = [_[1] for _ in sorted(limits.items())]
    pipe = r.pipeline()
    for period_length in period_lengths:
        # fixed window: every timestamp in the same period maps to the same counter
        current_period = int(time.time() / period_length)
        redis_key = 'rate_limit:{key}:{period_length}:{current_period}'.format(key=key, period_length=period_length, current_period=current_period)
        # count this hit; the counter expires three periods after its last hit
        pipe.incr(redis_key).expire(redis_key, period_length*3)
    # execute() returns [incr, expire, incr, expire, ...]; [::2] keeps the counts
    return not any(hits > period_limit for period_limit, hits in zip(period_limits, pipe.execute()[::2]))


if __name__ == '__main__':
    r = redis.Redis()
    print(rate_limit_check(r, '127.0.0.1', {1: 3, 10: 5}))
There's only one round trip to Redis, and Redis cleans up the old counters itself because every key gets an expiry. In this example the limits are max 3 requests per second and max 5 requests every 10 seconds, checked against an IP address, although it can work with any string you can construct.
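The counting scheme above can be modelled without Redis to see how it behaves: a plain dict stands in for the INCR counters (no expiry), and `now` is passed in. `fixed_window_check` is my own name; it mirrors the logic of `rate_limit_check`, not its Redis calls. Two properties fall out: every attempt is counted, denied ones included, and because the windows are fixed, a burst straddling a window boundary can get up to twice that window's limit through.

```python
import collections


def fixed_window_check(counters, key, limits, now):
    """Mirror of rate_limit_check's logic with a dict in place of Redis INCR."""
    allowed = True
    for period_length, period_limit in sorted(limits.items()):
        current_period = int(now / period_length)
        counter_key = (key, period_length, current_period)
        counters[counter_key] += 1  # every attempt is counted, even denied ones
        if counters[counter_key] > period_limit:
            allowed = False
    return allowed


if __name__ == "__main__":
    counters = collections.Counter()
    # Four hits in the same second with max 3/sec: [True, True, True, False]
    print([fixed_window_check(counters, "127.0.0.1", {1: 3, 10: 5}, 100.0) for _ in range(4)])
```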
 
would a real world solution not rate limit on both IP and api key?
 

i presume this is explained by whitespace, but running md5sum on the code from your post i'm getting:

me@myserver:~$ md5sum mattseh2
27c52789d9548825d4065373bb58b536 mattseh2
me@myserver:~$ vi mattseh2
me@myserver:~$ md5sum mattseh2
c817c10a05d6a76c18d477017cfb3bb5 mattseh2

i added a line break at the end for the second checksum

yep, i'm that sad that i checked!
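That fits: md5 covers every byte, so a single extra newline is enough to give a completely different digest. A quick check with Python's `hashlib` (the one-line snippet is just sample input):

```python
import hashlib

code = b"print(rate_limit_check(r, '127.0.0.1', {1: 3, 10: 5}))"

# Same text, with and without a trailing newline.
without_newline = hashlib.md5(code).hexdigest()
with_newline = hashlib.md5(code + b"\n").hexdigest()
print(without_newline == with_newline)  # False: one extra byte changes the whole digest
```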
 
Here's a Python wrapper around a Lua script I modified for this use case (original is here: https://gist.github.com/josiahcarlson/3cb0561707fd2ca1f176).

You can pass in a list of identifiers and conditions; all identifiers are checked against all conditions provided.
It allows simultaneous second/minute/hour/day limits (or any combination), with zero race conditions, since Redis runs each Lua script atomically.

It also includes a "weight" argument so you can implement logic around more expensive calls if needed (Facebook does something similar, described here: https://developers.facebook.com/docs/reference/ads-api/api-rate-limiting).
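The behaviour described above (every identifier checked against every limit, the call charged to all identifiers only if none would go over, each call costing `weight`) can be modelled in plain Python. This is a single-process sketch with exact timestamps: it skips the script's time bucketing and the atomicity Redis gives a Lua script, never prunes old entries, and the `over_limit` name and dict-of-lists storage are my own.

```python
import collections


def over_limit(log, identifiers, limits, weight=1, now=0.0):
    """Return True if any identifier would exceed any limit; otherwise record the call.

    log maps identifier -> list of (timestamp, weight); limits maps window seconds -> max total weight.
    """
    # First pass: check everything without changing anything.
    for ident in identifiers:
        for period, limit in limits.items():
            used = sum(w for ts, w in log[ident] if ts > now - period)
            if used + weight > limit:
                return True
    # Second pass: every check passed, so charge the call to every identifier.
    for ident in identifiers:
        log[ident].append((now, weight))
    return False


if __name__ == "__main__":
    log = collections.defaultdict(list)
    print(over_limit(log, ["ip:127.0.0.1", "key:abc"], {1: 5, 20: 10, 60: 25}))
```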

Code:
import time
import redis


def over_limit_sliding(conn, base_keys, single=5, twenty=10, sixty=25, weight=1):
    '''
    Will return whether the caller is over any of their limits. Uses a sliding
    schedule with one-second resolution (timestamps are whole seconds).

    Arguments:
        conn - a Redis connection object
        base_keys - how you want to identify the caller, pass a list of
                    identifiers
        single, twenty, sixty - limits for each resolution
        weight - how much does this "call" count for
    '''
    limits = [single, twenty, sixty, weight, int(time.time())]
    return bool(over_limit_sliding_lua(conn, keys=base_keys, args=limits))

  
def _script_load(script):
    '''
    Used because the API for the Python Lua scripting support is awkward.
    '''
    sha = [None]
    def call(conn, keys=[], args=[], force_eval=False):
        if not force_eval:
            if not sha[0]:
                sha[0] = conn.execute_command(
                    "SCRIPT", "LOAD", script, parse="LOAD")
            try:
                return conn.execute_command(
                    "EVALSHA", sha[0], len(keys), *(keys+args))
            except redis.exceptions.ResponseError as msg:
                if not msg.args[0].startswith("NOSCRIPT"):
                    raise
        return conn.execute_command(
            "EVAL", script, len(keys), *(keys+args))
    return call


over_limit_sliding_lua = _script_load('''
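local slice = {1, 20, 60}       -- window lengths in seconds
-- bucket widths in seconds; hits are grouped into buckets labelled with their
-- start time, so a hit can expire up to one bucket width early
local precision = {5, 10, 25}
local dkeys = {'m', 'h', 'd'}   -- key suffix for each window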
local ts = tonumber(table.remove(ARGV))
local weight = tonumber(table.remove(ARGV))
local fail = false

-- Make two passes, the first to clean out old data and make sure there is
-- enough available resources, the second to update the counts.
for _, ready in ipairs({false, true}) do
    -- iterate over all of the limits provided
    for i = 1, math.min(#ARGV, #slice) do
        local limit = tonumber(ARGV[i])

        -- make sure that it is a limit we should check
        if limit > 0 then
            -- calculate the cutoff times and suffixes for the keys
            local cutoff = ts - slice[i]
            local curr = '' .. (precision[i] * math.floor(ts / precision[i]))
            local suff = ':' .. dkeys[i]
            local suff2 = suff .. ':l'

            -- check each key to verify it is not above the limit
            for j, k in ipairs(KEYS) do
                local key = k .. suff
                local key2 = k .. suff2

                if ready then
                    -- if we get here, our limits are fine
                    redis.call('incrby', key, weight)
                    -- merge into the newest bucket (the last pair) if it is the current one
                    local newest = redis.call('lrange', key2, '-2', '-1')
                    if newest[2] == curr then
                        redis.call('ltrim', key2, 0, -3)
                        redis.call('rpush', key2, weight + tonumber(newest[1]), newest[2])
                    else
                        redis.call('rpush', key2, weight, curr)
                    end
                    redis.call('expire', key, slice[i])
                    redis.call('expire', key2, slice[i])

                else
                    -- get the current counted total
                    local total = tonumber(redis.call('get', key) or '0')

                    -- only bother to clean out old data on our first pass through,
                    -- we know the second pass won't do anything
                    while total + weight > limit do
                        local oldest = redis.call('lrange', key2, '0', '1')
                        if #oldest == 0 then
                            break
                        end
                        if tonumber(oldest[2]) <= cutoff then
                            total = tonumber(redis.call('incrby', key, -tonumber(oldest[1])))
                            redis.call('ltrim', key2, '2', '-1')
                        else
                            break
                        end
                    end

                    fail = fail or total + weight > limit
                end
            end
        end
    end
    if fail then
        break
    end
end

return fail
''')


def test(count=300):
    import uuid
    keys = [str(uuid.uuid4())]
    c = redis.Redis()
    
    t = time.time()
    total = 0
    for i in range(count):
        isblocked = over_limit_sliding(c, keys, single=5, twenty=10, sixty=25, weight=1)
        # Sleep between calls to test "sliding" accuracy/precision;
        # comment out the next line to measure raw throughput instead.
        time.sleep(1)
        if not isblocked:
            total += 1
    print("Sliding sequential performance LUA: {}".format(count / (time.time() - t)))
    print("{} total requests".format(total))
    
    
if __name__ == '__main__':
    test()
This thread is good timing for me because I'll need to implement rate limiting for a project I'm working on, so I was curious about performance & accuracy of different methods (would still like to test against zset/hash methods).

Tested the above vs your script (both using your original limits):

1 attempt/sec for 30 seconds
Code:
(dat-venv)action@storms-end-django-80046:~$ python ratelimit.py
Sliding sequential performance Mattseh: 0.998495635941
19 total requests
Sliding sequential performance LUA: 0.998543487603
20 total requests
1 attempt/sec for 5 minutes
Code:
(dat-venv)action@storms-end-django-80046:~$ python ratelimit.py
Sliding sequential performance Mattseh: 0.998520096054
95 total requests
Sliding sequential performance LUA: 0.998607614753
134 total requests
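I'd expect your version to perform better (which makes sense: plain INCR/EXPIRE commands in one round trip, no Lua), but at the expense of precision/accuracy. With the one-second sleep both scripts are capped at about one call per second, though, so the throughput numbers above don't really separate them.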

Didn't get a chance to dig into it yet, but this RateLimiter class looks like it might be useful too (or at least parts of it): https://github.com/DomainTools/rate-limit/blob/master/ratelimit.py.