Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

/**
* 限流 AOP 切面
* 支持可重复注解,逐条执行独立的限流规则,任一规则不通过即拒绝
* 支持可重复注解,同一方法上的多条限流规则在一次 Lua 调用中原子检查和扣减
*/
@Slf4j
@Aspect
Expand Down Expand Up @@ -75,30 +75,34 @@ public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
long nowMs = System.currentTimeMillis();
String requestId = UUID.randomUUID().toString();

List<Object> keysList = new ArrayList<>(rules.length);
List<Object> args = new ArrayList<>(3 + rules.length * 3);
List<RateLimitContext> contexts = new ArrayList<>(rules.length);
args.add(String.valueOf(nowMs));
args.add(requestId);
args.add(String.valueOf(rules.length));

for (RateLimit rule : rules) {
long intervalMs = calculateIntervalMs(rule.interval(), rule.timeUnit());
String key = generateKey(className, methodName, rule.dimension());

Long result = executeRateLimitScript(key, nowMs, requestId, intervalMs, rule.count());
keysList.add(key);
args.add(String.valueOf(1));
args.add(String.valueOf(intervalMs));
args.add(String.valueOf(rule.count()));
contexts.add(new RateLimitContext(rule, key));
}

if (result == null || result == 0) {
return handleRateLimitExceeded(joinPoint, rule, key);
}
Long result = executeRateLimitScript(keysList, args.toArray());
if (result == null || result <= 0) {
RateLimitContext failedContext = resolveFailedContext(result, contexts);
return handleRateLimitExceeded(joinPoint, failedContext.rule(), failedContext.key());
}

return joinPoint.proceed();
}

private Long executeRateLimitScript(String key, long nowMs, String requestId, long intervalMs, double count) {
List<Object> keysList = Collections.singletonList(key);
Object[] args = {
String.valueOf(nowMs),
String.valueOf(1),
String.valueOf(intervalMs),
String.valueOf(count),
requestId
};

private Long executeRateLimitScript(List<Object> keysList, Object[] args) {
try {
Object resultObj = rScript.evalSha(
RScript.Mode.READ_WRITE,
Expand All @@ -125,6 +129,16 @@ private Long executeRateLimitScript(String key, long nowMs, String requestId, lo
}
}

private RateLimitContext resolveFailedContext(Long result, List<RateLimitContext> contexts) {
if (result != null && result < 0) {
int failedIndex = (int) Math.min(Math.abs(result) - 1, contexts.size() - 1L);
if (failedIndex >= 0) {
return contexts.get(failedIndex);
}
}
return contexts.get(0);
}

private long calculateIntervalMs(long interval, RateLimit.TimeUnit unit) {
return switch (unit) {
case MILLISECONDS -> interval;
Expand Down Expand Up @@ -261,4 +275,7 @@ private String getCurrentUserId() {

return "anonymous";
}

private record RateLimitContext(RateLimit rule, String key) {
}
}
108 changes: 63 additions & 45 deletions app/src/main/resources/scripts/rate_limit_single.lua
Original file line number Diff line number Diff line change
@@ -1,60 +1,78 @@
-- 单维度限流脚本
-- 基于滑动时间窗口的单维度原子限流
-- 由切面逐条调用实现多维度限流
-- 多维度限流脚本
-- 基于滑动时间窗口的原子限流
-- 同一方法上的多条 @RateLimit 规则会在一次 Lua 调用中完成检查和扣减

-- 参数说明:
-- KEYS[1]: 限流维度键
-- KEYS[i]: 第 i 条限流维度键
-- ARGV[1]: 当前时间戳(毫秒)
-- ARGV[2]: 申请令牌数
-- ARGV[3]: 时间窗口(毫秒)
-- ARGV[4]: 最大令牌数(窗口内允许的总数)
-- ARGV[5]: 请求唯一标识
-- ARGV[2]: 请求唯一标识
-- ARGV[3]: 规则数量
-- ARGV[4] 开始,每条规则 3 个参数:
-- 申请令牌数、时间窗口(毫秒)、最大令牌数(窗口内允许的总数)

local key = KEYS[1]
local now_ms = tonumber(ARGV[1])
local permits = tonumber(ARGV[2])
local interval = tonumber(ARGV[3])
local max_tokens = tonumber(ARGV[4])
local request_id = ARGV[5]

local value_key = key .. ":value"
local permits_key = key .. ":permits"

-- 获取当前可用令牌(不存在则使用 max_tokens)
local current_val = tonumber(redis.call("get", value_key)) or max_tokens

-- 回收过期令牌
local expired_values = redis.call("zrangebyscore", permits_key, 0, now_ms - interval)
if #expired_values > 0 then
local expired_count = 0
for _, v in ipairs(expired_values) do
local p = tonumber(string.match(v, ":(%d+)$"))
if p then
expired_count = expired_count + p
local request_id = ARGV[2]
local rule_count = tonumber(ARGV[3])

local current_values = {}
local permit_values = {}
local intervals = {}

-- 第一阶段:回收过期令牌并检查所有维度。任一维度不满足时,不扣任何新令牌。
for i = 1, rule_count do
local arg_index = 4 + (i - 1) * 3
local key = KEYS[i]
local permits = tonumber(ARGV[arg_index])
local interval = tonumber(ARGV[arg_index + 1])
local max_tokens = tonumber(ARGV[arg_index + 2])
local value_key = key .. ":value"
local permits_key = key .. ":permits"

local current_val = tonumber(redis.call("get", value_key)) or max_tokens

local expired_values = redis.call("zrangebyscore", permits_key, 0, now_ms - interval)
if #expired_values > 0 then
local expired_count = 0
for _, v in ipairs(expired_values) do
local p = tonumber(string.match(v, ":(%d+)$"))
if p then
expired_count = expired_count + p
end
end
end

redis.call("zremrangebyscore", permits_key, 0, now_ms - interval)
redis.call("zremrangebyscore", permits_key, 0, now_ms - interval)

if expired_count > 0 then
current_val = math.min(max_tokens, current_val + expired_count)
if expired_count > 0 then
current_val = math.min(max_tokens, current_val + expired_count)
end
end

if current_val < permits then
return -i
end
end

-- 检查可用令牌
if current_val < permits then
return 0
current_values[i] = current_val
permit_values[i] = permits
intervals[i] = interval
end

-- 扣减令牌
local permit_record = request_id .. ":" .. permits
redis.call("zadd", permits_key, now_ms, permit_record)
redis.call("set", value_key, current_val - permits)
-- 第二阶段:所有维度都通过后再统一扣减。
for i = 1, rule_count do
local key = KEYS[i]
local value_key = key .. ":value"
local permits_key = key .. ":permits"
local permits = permit_values[i]
local current_val = current_values[i]
local interval = intervals[i]

-- 设置过期时间(窗口的2倍,至少1秒)
local expire_time = math.ceil(interval * 2 / 1000)
if expire_time < 1 then expire_time = 1 end
redis.call("expire", value_key, expire_time)
redis.call("expire", permits_key, expire_time)
local permit_record = request_id .. ":" .. i .. ":" .. permits
redis.call("zadd", permits_key, now_ms, permit_record)
redis.call("set", value_key, current_val - permits)

local expire_time = math.ceil(interval * 2 / 1000)
if expire_time < 1 then expire_time = 1 end
redis.call("expire", value_key, expire_time)
redis.call("expire", permits_key, expire_time)
end

return 1
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.springframework.core.io.ClassPathResource;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
Expand Down Expand Up @@ -80,7 +81,7 @@ void testRateLimit() {
}

@Test
@DisplayName("验证多规则限流:逐条检查,任一规则不足即拒绝")
@DisplayName("验证多规则限流:任一规则不足时不扣减其他规则")
void testMultiRule() {
String globalKey = "ratelimit:test:multi:global";
String ipKey = "ratelimit:test:multi:ip";
Expand All @@ -92,12 +93,12 @@ void testMultiRule() {
redissonClient.getBucket(ipKey + ":value", StringCodec.INSTANCE).set("1");

// 第一次请求:两条规则都通过
assertEquals(1L, executeLuaScript(globalKey, globalMax));
assertEquals(1L, executeLuaScript(ipKey, ipMax));
assertEquals(1L, executeLuaScript(List.of(globalKey, ipKey), List.of(globalMax, ipMax)));

// 第二次请求:全局规则通过,IP规则拒绝(模拟短路)
assertEquals(1L, executeLuaScript(globalKey, globalMax));
assertEquals(0L, executeLuaScript(ipKey, ipMax));
// 第二次请求:IP规则拒绝,全局规则不能被误扣
assertEquals(-2L, executeLuaScript(List.of(globalKey, ipKey), List.of(globalMax, ipMax)));
assertEquals("9", redissonClient.getBucket(globalKey + ":value", StringCodec.INSTANCE).get());
assertEquals("0", redissonClient.getBucket(ipKey + ":value", StringCodec.INSTANCE).get());
}

@Test
Expand All @@ -120,24 +121,29 @@ void testIndependentCountPerDimension() {
}

private long executeLuaScript(String key, long maxCount) {
return executeLuaScript(Collections.singletonList(key), Collections.singletonList(maxCount));
}

private long executeLuaScript(List<String> keys, List<Long> maxCounts) {
RScript script = redissonClient.getScript(StringCodec.INSTANCE);

Object[] args = {
String.valueOf(System.currentTimeMillis()),
String.valueOf(1),
String.valueOf(1000),
String.valueOf(maxCount),
UUID.randomUUID().toString()
};
List<Object> args = new ArrayList<>(3 + keys.size() * 3);
args.add(String.valueOf(System.currentTimeMillis()));
args.add(UUID.randomUUID().toString());
args.add(String.valueOf(keys.size()));

List<Object> keysList = Collections.singletonList(key);
for (Long maxCount : maxCounts) {
args.add(String.valueOf(1));
args.add(String.valueOf(1000));
args.add(String.valueOf(maxCount));
}

Object result = script.evalSha(
RScript.Mode.READ_WRITE,
luaScriptSha,
RScript.ReturnType.VALUE,
keysList,
args
new ArrayList<>(keys),
args.toArray()
);

if (result instanceof Number) {
Expand Down