asgi rate limiting with a sliding window

I recently needed to limit the number of calls on some endpoints of a FastAPI based application.

The asgi landscape has some pretty nice characteristics, one of them being well explained in this post [1], I quote it here:

Well, because ASGI is an abstraction which allows telling the context we are in, and to receive and send data at any time, there's this idea that ASGI can be used not only between servers and apps, but really at any point in the stack.

In that regard, this package asgi-ratelimit [2] embraces the concept nicely.

This rate limiter is implemented as a middleware; its configuration is simple yet powerful, and it is flexible and easy to extend.

        config={
            r"^/towns": [Rule(second=1, group="default"), Rule(group="admin")],
            r"^/forests": [Rule(minute=1, group="default"), Rule(group="admin")],
        }

In the process of testing it, I wanted to use the RedisBackend with several time windows at once: specifically, I wanted those endpoints to allow a lower request rate as the window grows.

In the context of the RateLimitMiddleware that would translate to a config with the following settings: config={r"^/multiple": Rule(second=5, minute=100, hour=1000)}

I quickly realized that multiple windows were not well supported, so I came up with this SlidingRedisBackend that can enforce the previous rule.

import json
import time

from aredis import StrictRedis  # async redis client (assumed; matches asgi-ratelimit's redis backend)

# BaseBackend and Rule come from the asgi-ratelimit package

class SlidingRedisBackend(BaseBackend):
    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str = None,
    ) -> None:
        self._redis = StrictRedis(host=host, port=port, db=db, password=password)
        self.sliding_function = self._redis.register_script(SLIDING_WINDOW_SCRIPT)

    async def get_limits(self, path: str, user: str, rule: Rule) -> bool:
        epoch = time.time()
        ruleset = rule.ruleset(path, user)
        keys = list(ruleset.keys())
        args = [epoch, json.dumps(ruleset)]
        # quoted_args = [f"'{a}'" for a in args]
        # cli = f"redis-cli --ldb --eval /tmp/script.lua {' '.join(keys)} , {' '.join(quoted_args)}"
        # logger.debug(cli)
        r = await self.sliding_function.execute(keys=keys, args=args)
        # from tests.backends.test_redis import logger
        # logger.debug(f"{epoch} {r} : {all(r)}")
        return all(r)

In this implementation, if get_limits returns True the request is accepted; if not, the middleware answers with a nice 429.
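For completeness, here is a minimal sketch of how such a backend could be wired into a FastAPI app, following asgi-ratelimit's middleware API; the authenticate function and its identity logic are assumptions made up for the example:

```python
from fastapi import FastAPI
from ratelimit import RateLimitMiddleware, Rule

app = FastAPI()

# hypothetical authenticate function: derive a (user, group) pair from the scope
async def client_identity(scope):
    # a real app would inspect headers or a session; here we just use the client ip
    host = scope["client"][0] if scope.get("client") else "anonymous"
    return host, "default"

app.add_middleware(
    RateLimitMiddleware,
    authenticate=client_identity,
    backend=SlidingRedisBackend(host="localhost"),  # the backend defined above
    config={r"^/multiple": Rule(second=5, minute=100, hour=1000)},
)
```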
The astute reader will also notice the commented-out cli variable; it proved very helpful for step debugging the Lua script:

  1. put a breakpoint after it,
  2. run redis inside a docker container, with the /tmp/script.lua you need to debug mounted as a volume,
  3. copy-paste the cli command in a shell in the container and you're step debugging the Lua script! (see [3] for a nice overview of this process)
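Concretely, assuming the script was dumped to /tmp/script.lua on the host, the session could look like this (the container name and the example key/args are made up for illustration):

```shell
# run redis with the script mounted from the host
docker run --rm -d --name redis-ldb -v /tmp/script.lua:/tmp/script.lua redis

# start redis-cli in Lua debugger mode inside the container,
# passing one example key, then the epoch and the json ruleset as args
docker exec -it redis-ldb \
  redis-cli --ldb --eval /tmp/script.lua /multiple:user1:second , \
  1600000000.0 '{"/multiple:user1:second": [5, 1]}'
```

From there, redis-cli drops you into its interactive debugger, where `s` steps through the script line by line.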

Note that there is just one redis call here, which executes the Lua script SLIDING_WINDOW_SCRIPT:

-- Set variables from arguments
local now = tonumber(ARGV[1])
local ruleset = cjson.decode(ARGV[2])
-- ruleset looks like this:
-- {key: [limit, window_size], ...}
local scores = {}
for i, pgname in ipairs(KEYS) do
    -- we remove entries older than now - window_size
    local clearBefore = now - ruleset[pgname][2]
    redis.call('ZREMRANGEBYSCORE', pgname, 0, clearBefore)
    -- we get the count
    local amount = redis.call('ZCARD', pgname)
    -- we add to the sorted set if allowed, ie the amount < limit
    if amount < ruleset[pgname][1] then
        redis.call('ZADD', pgname, now, now)
    end
    -- cleanup, this expires the whole set in window_size secs
    redis.call('EXPIRE', pgname, ruleset[pgname][2])
    -- calculate the remaining amount of requests. If > 0 then the request for that window is allowed
    scores[i] = ruleset[pgname][1] - amount
end
return scores
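To make the per-key logic concrete, here is a rough pure-Python model of what the script does for a single key; sliding_window_check is a hypothetical helper written for this post, not part of any package:

```python
import time

def sliding_window_check(timestamps, limit, window_size, now=None):
    """Pure-Python model of the Lua script's per-key logic."""
    now = time.time() if now is None else now
    # 1. drop timestamps that fell out of the window (ZREMRANGEBYSCORE)
    timestamps[:] = [t for t in timestamps if t > now - window_size]
    # 2. current number of calls in the window (ZCARD)
    amount = len(timestamps)
    # 3. record this call only if we are under the limit (ZADD)
    if amount < limit:
        timestamps.append(now)
    # 4. remaining quota; the request is allowed when this is > 0
    return limit - amount

calls = []
results = [sliding_window_check(calls, limit=5, window_size=1, now=100.0)
           for _ in range(6)]
# results == [5, 4, 3, 2, 1, 0] -> the 6th call in the same second is rejected
```

Once the window slides past the recorded timestamps, the quota is restored, which is exactly the behaviour the ZREMRANGEBYSCORE pruning gives us.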

The script I came up with is just a mix of those 2 excellent blog posts (see [4] and [5]).

The idea is the following: on each api call, a key identifies the user calling it; in asgi-ratelimit it takes the following shape:
path:user:name, where name is one of <second|minute|hour|day|month>

So, taking the config from the beginning as an example (config={r"^/multiple": Rule(second=5, minute=100, hour=1000)}), when user1 wants to access the /multiple endpoint (note this can match several urls, as it is a pattern), every api call touches three keys: /multiple:user1:second, /multiple:user1:minute and /multiple:user1:hour.

For each redis key, the corresponding value is a sorted set of timestamps, and the number of items in the set determines the number of calls made on that key.
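As a hypothetical illustration (in the real backend this mapping is produced by rule.ruleset(path, user)), the {key: [limit, window_size]} structure handed to the Lua script for the /multiple example could look like this:

```python
# window sizes in seconds for each named window
windows = {"second": 1, "minute": 60, "hour": 3600}
# limits from Rule(second=5, minute=100, hour=1000)
limits = {"second": 5, "minute": 100, "hour": 1000}
path, user = "/multiple", "user1"

# reconstruct the {key: [limit, window_size]} mapping described above
ruleset = {
    f"{path}:{user}:{name}": [limit, windows[name]]
    for name, limit in limits.items()
}
# ruleset["/multiple:user1:second"] == [5, 1]
```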

In order to do that we will loop over each key and do 4 things, in order:

  1. remove from this sorted set the timestamps older than clearBefore = now - window_size: it's a sliding window and they should not count against the number of calls left. This is done with the ZREMRANGEBYSCORE command [6], which removes all elements of the set with a score between 0 and clearBefore
  2. get the current count of the sorted set with ZCARD [7]
  3. add the current timestamp to the set if we are below the limit, with ZADD [8]
  4. though technically not required, set an expiration of window_size on the key so that we save space

Once all that is done, we calculate a score which equals the limit minus the amount, and we return the array of scores, one per key.

The rest of the Python code allows the request only if all scores are strictly positive: all(r) is False as soon as one score reaches 0.

Note that in this implementation a blocked action still counts against you in the windows that were not yet exhausted; I'd be interested in seeing whether this can be done without.

Sources I liked during this quick journey, in no particular order: [9], [10], [11], [12].
