首页 > 经验记录 > 探秘分布式解决方案: 分布式限流——Redis版分布式信号量原理 (附RedisTemplate具体实现代码)

探秘分布式解决方案: 分布式限流——Redis版分布式信号量原理 (附RedisTemplate具体实现代码)

 
关于限流这种机制呢也算是老生常谈了, 毕竟在业务开发中实在是很多地方都会用到。比如第三方接口调用限制、并发访问数控制等…
 
而具体的限流算法在单机中很容易就可以实现,  在java的世界里既有开源库Guava的RateLimiter,  也有JUC中自带的 Semaphore、BlockingQueue等。拿来随手就可以使用。
 
那么在分布式场景中,  限流就变得不是那么容易了。 在这种环境上想实现限流本质上是在实现一种多个进程之间的协同工作机制。 必须得依靠一个可靠的协调中心才行,这一般都会选一种中间件来实现。
而 Redis 其实就是一个适合这种场景的中间件,既快,又有强大的数据结构对各种限流算法提供支持。那么我这里就来基于Redis实现一种简单粗暴又好用的限流方案——信号量(Semaphore)
 

关于信号量

这里引用一段维基百科的定义

信号量(英语:semaphore)又称为信号标,是一个同步对象,用于保持在0至指定最大值之间的一个计数值。当线程完成一次对该semaphore对象的等待(wait)时,该计数值减一;当线程完成一次对semaphore对象的释放(release)时,计数值加一。当计数值为0,则线程等待该semaphore对象不再能成功直至该semaphore对象变成signaled状态。semaphore对象的计数值大于0,为signaled状态;计数值等于0,为nonsignaled状态.

 
其实白话说起来很简单,信号量就是可以被多个线程同时持有一种同步对象,比如我设置一个值为5的计数信号量,那么现在有十个线程来获取他就只会有五个可以成功,剩下那五个则获取失败。
所以说如果有个计数信号量定义的值是1,那么他其实就等同于 mutex (互斥锁)
 
 

实现的基本思路

既然知道信号量本质是一种锁,那么对于信号量需要拥有的效果自然就有了思路

  • 拥有获取、释放的机制
  • 需要知道是哪个客户端获取到了信号量
  • 获取到信号量之后, 不能因为客户端的崩溃导致无法释放

 
 
对于Redis来说,可以使用ZSet来实现这些效果。ZSet 是不可重复的有序集合, 内部每个元素都拥有一个属于自己的分数 (score)
那么我们可以将 ZSet 中一条数据视为客户端获取到的信号量, key就是客户端的唯一标识, score 可以设置为客户端获取信号量的时间
这样就能够实现上面所说的几种机制

  • 在ZSet中插入数据即为获取。  删除即为释放。
  • 利用ZSet的有序特性,  可以根据 score 的排名来判断是否成功获取到了信号量
  • 因为 score 存的是客户端获取到信号量的时间, 所以可以约定一个过期时间来对死掉的客户端获取到的信号量进行清除

 
 
值得注意的是,  在 Redis 上实现信号量,如果客户端持有信号量之后由于处理时间太久导致没在规定的超时时间内释放的话, 那么这个持有信号量延时机制, 是需要自己实现的 (守护线程定时更新等)  因为Redis他本身没有提供这种功能的实现, 所以只能自己动手了。  不过像这种需求在使用到分布式信号量的场景中多数不怎么会出现,  所以也可以不管。
 
 

使用 Redis 构建分布式信号量的实现细节

知道了需求和对应的实现思路后,  那么可以来决定一下具体的实现细节
这里就需要对 Redis 的命令有一定的了解,  不过就算不了解也没关系, 反正就这么几个命令。
可以从下面的网址中参考命令的效果
https://redis.io/commands
http://redisdoc.com
http://doc.redisfans.com
 
可行的具体流程

  1. 获取系统当前时间,  因为集合的分数储存的是时间毫秒值, 所以可以通过 ZREMRANGEBYSCORE 清理掉过期的信号量
  2. 使用 ZADD 向集合中添加代表自身信号量的元素, 分数为当前时间
  3. 通过 ZRANK 得到当前客户端在集合中的排名, 如果在许可证数量的范围内 (即不大于信号量最大持有数量)  即视为成功获取信号量
  4. 如果不在范围内, 比如信号量设置的最大许可数为 5, 自己在集合中的排名是5 (Redis rank 从0开始数) 则视为获取型号量失败, 使用 ZREM 清理数据

 
使用上边说的流程来进行分布式信号量的实现是很好用的,既简单又快速, 很多时候用这个就可以了。
但是它其实还有一个小小的问题。 因为其判定时间的逻辑会存在于客户端中 , 而在不同的主机环境上时间并不一定会完全一致,可能会有个几毫秒的误差,这样子就有可能出现信号量超发的问题。
比如这么一个场景,  在获取最后一个信号量的时候, 客户端A 已经获取到了最后一个信号量,  这个时候客户端B (B的时间比A要慢一点) 也来尝试获取信号量, 那么在B判定分数的时候有可能就会发现自己的排名仍在信号量的许可证最大数范围内, 从而B也拿到了这最后一个信号量。
其实在很多场景中偶尔超标问题不大, 再加上本身也是小概率事件, 所以很多时候可以无视这个问题。 不过既然有这么个问题,那还是可以继续优化下去的,让具体的实现更加公平一点
 
 
优化后的逻辑
知道了会发生超发的原因就是因为在比较排名时用的是保存获取时间的集合,而根据获取时间集合中的排名来判断是有可能和实际排名对不上的。
那么我们就可以想个办法将排名控制逻辑单独拿出来,将其放到一个可靠的地方来实现。这个可靠的地方当然也可以使用Redis啦。
 
可以利用Redis自带的原子性自增 INCR 来获取序列号。  再新增一个有序集合以这个序列号作为score,  这样就可以根据这个排名来进行公平的判定。
而需要做的只是在  获取信号量、获取失败清除数据  这两个步骤中增加对于专门的排名集合的操作。
并且在获取信号量之前清理过期数据时同时清理排名集合中的数据即可,  这里可以使用ZSet的 ZINTERSTORE 取交集并储存来实现。
这样就是一个相对之前的逻辑而言更完备的实现了。 只要你的系统时间差的不是那么多, 那都是可以公平的安全获取的。  如果想要完全保证公平,也可以用锁机制来实现,不过一般来说没必要。
 
 
 

分布式信号量具体实现代码

知道了具体的细节后,就可以进入到编码流程了。
想自己动手实现的可以先按照步骤自己实现。可以用我的代码作为参考。
 
这里就贴上我写的实现: 
 
首先先定义一个信号量的基础信息, 表示一个信号量的元数据。
一般是事先手动配置好的。可以放在配置文件、配置中心、SQL数据库里。

public final class SemaphoreInfo {
    //信号量的名称
    private final String semaphoreName;
    //许可证的数量
    private final int permits;
    //信号量最大持有时间 (过期时间) 单位s
    private final long expire;
    //公平 or 非公平
    private final boolean fair;
    public SemaphoreInfo(String semaphoreName, int permits, long expire) {
        this(semaphoreName, permits, expire, false);
    }
    public SemaphoreInfo(String semaphoreName, int permits, long expire, boolean fair) {
        this.semaphoreName = semaphoreName;
        this.permits = permits;
        this.expire = expire;
        this.fair = fair;
    }
    public String getSemaphoreName() {
        return semaphoreName;
    }
    public int getPermits() {
        return permits;
    }
    public long getExpire() {
        return expire;
    }
    public boolean isFair() {
        return fair;
    }
}

 
 
有了元信息之后, 就可以开始实现具体的需求了
先定义一个接口

public interface DistributedSemaphore {
    /**
     * 尝试获取一个信号量
     *
     * @return true 获取成功, false 获取失败
     */
    boolean tryAcquire();
    /**
     * 释放自己持有的信号量
     */
    void release();
}

 
这个接口的具体实现, RedisTemplate 版
注释比较详细, 就不过多说明了。

public class RedisSemaphore implements DistributedSemaphore {
    private static final String SEMAPHORE_TIME_KEY = "semaphore:time:";
    private static final String SEMAPHORE_OWNER_KEY = "semaphore:owner:";
    private static final String SEMAPHORE_COUNTER_KEY = "semaphore:counter:";
    private final RedisTemplate redisTemplate;
    private final String timeKey;
    private final String ownerKey;
    private final String counterKey;
    //信号量的信息
    private final SemaphoreInfo info;
    //信号量实体
    private final DistributedSemaphore semaphore;
    //身份证明
    private final String identification;
    public RedisSemaphore(SemaphoreInfo info, RedisTemplate redisTemplate, String identification) {
        this.info = info;
        this.redisTemplate = redisTemplate;
        this.timeKey = SEMAPHORE_TIME_KEY.concat(info.getSemaphoreName());
        this.ownerKey = SEMAPHORE_OWNER_KEY.concat(info.getSemaphoreName());
        this.counterKey = SEMAPHORE_COUNTER_KEY.concat(info.getSemaphoreName());
        this.semaphore = info.isFair() ? new FairSemaphore() : new NonfairSemaphore();
        this.identification = identification;
    }
    @Override
    public boolean tryAcquire() {
        return semaphore.tryAcquire();
    }
    @Override
    public void release() {
        semaphore.release();
    }
    private class NonfairSemaphore implements DistributedSemaphore {
        @Override
        public boolean tryAcquire() {
            ZSetOperations zsetOps = redisTemplate.opsForZSet();
            long timeMillis = System.currentTimeMillis();
            //先清除过期的信号量
            zsetOps.removeRangeByScore(timeKey, 0, timeMillis - (info.getExpire() * 1000));
            //尝试获取信号量并比较自身的排名, 如果小于许可证的数量则表示获取成功 (redis rank 指令从0开始计数)
            zsetOps.add(timeKey, identification, timeMillis);
            if (zsetOps.rank(timeKey, identification) < info.getPermits()) return true;
            //获取失败,删除掉上边添加的标识
            release();
            return false;
        }
        @Override
        public void release() {
            redisTemplate.opsForZSet().remove(timeKey, identification);
        }
    }
    private class FairSemaphore implements DistributedSemaphore {
        @Override
        public boolean tryAcquire() {
            long timeMillis = System.currentTimeMillis();
            //用于获取信号量的计数
            Long counter = redisTemplate.opsForValue().increment(counterKey, 1);
            //用流水线把这一堆命令用一次IO全部发过去
            redisTemplate.executePipelined(new SessionCallback<Object>() {
                @Override
                public <K, V> Object execute(RedisOperations<K, V> operations) throws DataAccessException {
                    ZSetOperations zsetOps = operations.opsForZSet();
                    //清除过期的信号量
                    zsetOps.removeRangeByScore(timeKey, 0, timeMillis - (info.getExpire() * 1000));
                    zsetOps.intersectAndStore(timeKey, ownerKey, timeKey);
                    //尝试获取信号量
                    zsetOps.add(timeKey, identification, timeMillis);
                    zsetOps.add(ownerKey, identification, counter);
                    return null;
                }
            });
            //这里根据 持有者集合 的分数来进行判断
            Long ownerRank = redisTemplate.opsForZSet().rank(ownerKey, identification);
            if (ownerRank<info.getPermits()) return true;
            release();
            return false;
        }
        @Override
        public void release() {
            redisTemplate.executePipelined(new SessionCallback<Object>() {
                @Override
                public <K, V> Object execute(RedisOperations<K, V> operations) throws DataAccessException {
                    ZSetOperations zetOps = operations.opsForZSet();
                    zetOps.remove(timeKey, identification);
                    zetOps.remove(ownerKey, identification);
                    return null;
                }
            });
        }
    }
}

 
 
写完后可以来一波测试, 这是测试类信息

@RunWith(SpringRunner.class)
@SpringBootTest(classes = DemoApplication.class)
public class DistributedSemaphoreTest {
    @Autowired
    private RedisTemplate redisTemplate;
    ThreadPoolExecutor pool = new ThreadPoolExecutor(10, 10, 0, TimeUnit.MINUTES,
            new LinkedBlockingQueue<>(),
            Executors.defaultThreadFactory(),
            new ThreadPoolExecutor.CallerRunsPolicy());
    @Test
    public void testNonFair() throws InterruptedException {
        SemaphoreInfo semaphoreInfo = new SemaphoreInfo("NonFair", 5, 10);
        for (int i = 0; i < 10; i++) {
            String id = String.valueOf(i);
            RedisSemaphore semaphore = new RedisSemaphore(semaphoreInfo, redisTemplate, id);
            CompletableFuture.supplyAsync(semaphore::tryAcquire, pool).thenAcceptAsync((r) -> {
                if (r) System.out.println(id + "成功获取到信号量(NonFair)~ ⭐⭐⭐");
                else System.out.println(id + "没有获取到信号量(NonFair)");
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                semaphore.release();
            }, pool);
        }
        Thread.sleep(Long.MAX_VALUE);
    }
    @Test
    public void testFair() throws ExecutionException, InterruptedException {
        SemaphoreInfo semaphoreInfo = new SemaphoreInfo("Fair", 5, 10, true);
        for (int i = 0; i < 10; i++) {
            String id = String.valueOf(i);
            RedisSemaphore semaphore = new RedisSemaphore(semaphoreInfo, redisTemplate, id);
            CompletableFuture.supplyAsync(semaphore::tryAcquire, pool).thenAcceptAsync((r) -> {
                if (r) System.out.println(id + "成功获取到信号量(Fair)~~ ⭐⭐⭐");
                else System.out.println(id + "没有获取到信号量(Fair)");
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                semaphore.release();
            }, pool);
        }
        Thread.sleep(Long.MAX_VALUE);
    }
}

 
 

结语

其实吧.  分布式信号量还是一个挺简单的东西。
相信仔细看了我这篇博客之后,  应该对于分布式信号量的实现心中都有数了。
只要知道构建一个信号量需要做什么,那么可以不用拘泥于Redis, 用其他的中间件也可以实现, 毕竟思想是相通的。
在java环境下,  用这个我提供的实现其实也完全OK 。
不过信号量也只是限流算法中的其中一种实现而已,  对于限流场景还有其他的算法可以使用, 并且就算是这个分布式信号量,在Redis层面中也是可以继续优化下去的。看之后有没有心情写吧。
 

           


CAPTCHAis initialing...

5 COMMENTS

  1. L2021-04-28 14:38

    zsetOps.intersectAndStore(timeKey, ownerKey, timeKey);
    这里是不是错了,感觉应该是
    zsetOps.intersectAndStore(timeKey, ownerKey, ownerKey);

    取交集之后,分数相加。是不是会导致信号量超标?

    • canglin2021-05-15 22:38

      我经过完整测试的哦, 下面的测试类可以跑一跑呢

      • L2021-05-25 12:14

        我用的不是java,这里没有看明白所以问下作者。

        这里根据时间戳删除了 timeKey 中过期数据
        zsetOps.removeRangeByScore(timeKey, 0, timeMillis – (info.getExpire() * 1000));

        下面应该是删除 ownerKey 中对应的过期数据吧?
        但是 ownerKey 没有时间戳,无法判断是否过期。所以通过与 timeKey 取交集,然后再存储到 ownerKey 来达到删除 ownerKey 中过期数据的目的,应该是这样吧?
        zsetOps.intersectAndStore(timeKey, ownerKey, timeKey);

        intersectAndStore 方法第三个参数是取交集后存储的key,这里不就成了 timeKey 和 ownerKey 取交集后,又存到 timeKey 里了,ownerKey 中的过期数据依旧存在。

        • R2022-03-03 14:19

          我觉的 L 说的有道理 :mrgreen:

EA PLAYER &

历史记录 [ 注意:部分数据仅限于当前浏览器 ]清空

      00:00/00:00