Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ dependencies {
// quartz
implementation 'org.springframework.boot:spring-boot-starter-quartz:3.4.1'

// bucket4j
implementation 'com.bucket4j:bucket4j_jdk17-core:8.15.0'
implementation 'com.bucket4j:bucket4j_jdk17-redis:8.14.0'
implementation 'com.bucket4j:bucket4j_jdk17-lettuce:8.15.0'

asciidoctorExt 'org.springframework.restdocs:spring-restdocs-asciidoctor'

testImplementation 'org.springframework.restdocs:spring-restdocs-mockmvc'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.sillim.recordit.config.cache;

import io.lettuce.core.RedisClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand All @@ -25,12 +26,15 @@ public RedisConnectionFactory redisConnectionFactory() {
@Bean
public RedisTemplate<String, Object> redisTemplate() {
RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();

redisTemplate.setConnectionFactory(redisConnectionFactory());

redisTemplate.setKeySerializer(new StringRedisSerializer());
redisTemplate.setValueSerializer(new StringRedisSerializer());

return redisTemplate;
}

@Bean
public RedisClient redisClient() {
return RedisClient.create("redis://" + host + ":" + port);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package com.sillim.recordit.config.filter;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.sillim.recordit.global.dto.response.ErrorResponse;
import com.sillim.recordit.global.exception.ErrorCode;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.BucketConfiguration;
import io.github.bucket4j.ConsumptionProbe;
import io.github.bucket4j.distributed.proxy.ProxyManager;
import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;

@Slf4j
@Component
@RequiredArgsConstructor
public class ApiThrottlingFilter implements Filter {

private static final List<LimitApi> LIMIT_APIS =
List.of(
LimitApi.pattern("/api/v1/invite/members/*"),
LimitApi.pattern("POST", "/api/v1/members/*/follow"),
LimitApi.pattern("POST", "/api/v1/feeds/**"));
public static final int TOO_MANY_REQUEST = 429;

private final ProxyManager<String> proxyManager;
private final BucketConfiguration bucketConfiguration;
private final AntPathMatcher antPathMatcher = new AntPathMatcher();
private final ObjectMapper objectMapper;

@Override
public void doFilter(
ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) servletRequest;

String auth = request.getHeader(HttpHeaders.AUTHORIZATION);
if (auth == null) {
filterChain.doFilter(servletRequest, servletResponse);
return;
}

Bucket bucket = proxyManager.getProxy(auth, () -> bucketConfiguration);
for (LimitApi limitApi : LIMIT_APIS) {
if (antPathMatcher.match(limitApi.getUrl(), request.getRequestURI())
&& (limitApi.noMethod() || limitApi.getMethod().equals(request.getMethod()))) {
checkApiToken(bucket, filterChain, servletRequest, servletResponse);
return;
}
}

filterChain.doFilter(servletRequest, servletResponse);
}

private void checkApiToken(
Bucket bucket,
FilterChain filterChain,
ServletRequest request,
ServletResponse response)
throws IOException, ServletException {
ConsumptionProbe probe = bucket.tryConsumeAndReturnRemaining(1);

if (probe.isConsumed()) {
filterChain.doFilter(request, response);
return;
}

long waitForRefill = probe.getNanosToWaitForRefill() / 1_000_000_000;

HttpServletResponse httpResponse = (HttpServletResponse) response;
httpResponse.setContentType("text/plain; charset=UTF-8");
httpResponse.setStatus(TOO_MANY_REQUEST);
httpResponse.setCharacterEncoding(StandardCharsets.UTF_8.name());

response.getWriter()
.write(
objectMapper.writeValueAsString(
ResponseEntity.status(TOO_MANY_REQUEST)
.body(
ErrorResponse.from(
ErrorCode.TOO_MANY_REQUEST,
waitForRefill + "초 뒤에 다시 시도해주세요"))));
}
}
27 changes: 27 additions & 0 deletions src/main/java/com/sillim/recordit/config/filter/LimitApi.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.sillim.recordit.config.filter;

import lombok.Getter;

@Getter
public class LimitApi {

private final String method;
private final String url;

private LimitApi(String method, String url) {
this.method = method;
this.url = url;
}

public static LimitApi pattern(String method, String url) {
return new LimitApi(method, url);
}

public static LimitApi pattern(String url) {
return new LimitApi(null, url);
}

public boolean noMethod() {
return this.method == null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.sillim.recordit.config.ratelimiter;

import io.github.bucket4j.BucketConfiguration;
import io.github.bucket4j.distributed.ExpirationAfterWriteStrategy;
import io.github.bucket4j.redis.lettuce.Bucket4jLettuce;
import io.github.bucket4j.redis.lettuce.cas.LettuceBasedProxyManager;
import io.lettuce.core.RedisClient;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.codec.ByteArrayCodec;
import io.lettuce.core.codec.RedisCodec;
import io.lettuce.core.codec.StringCodec;
import java.time.*;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
@RequiredArgsConstructor
public class RateLimiterConfig {

private final RedisClient redisClient;

@Value("${rate-limiter.capacity:5}")
private int capacity;

@Value("${rate-limiter.refill-token-amount:5}")
private int refillTokenAmount;

@Value("${rate-limiter.refill-duration-seconds:10}")
private long refillDurationSeconds;

@Value("${rate-limiter.bucket-ttl-seconds:600}")
private long bucketTTLSeconds;

@Bean
public LettuceBasedProxyManager<String> proxyManager() {
StatefulRedisConnection<String, byte[]> connect =
redisClient.connect(RedisCodec.of(StringCodec.UTF8, ByteArrayCodec.INSTANCE));
return Bucket4jLettuce.casBasedBuilder(connect)
.expirationAfterWrite(
ExpirationAfterWriteStrategy.fixedTimeToLive(
Duration.ofSeconds(bucketTTLSeconds)))
.build();
}

@Bean
public BucketConfiguration bucketConfiguration() {
return BucketConfiguration.builder()
.addLimit(
limit ->
limit.capacity(capacity)
.refillIntervally(
refillTokenAmount,
Duration.ofSeconds(refillDurationSeconds)))
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public enum ErrorCode {
INVALID_ARGUMENT("ERR_GLOBAL_001", "올바르지 않은 값이 전달되었습니다."),
REQUEST_NOT_FOUND("ERR_GLOBAL_002", "요청을 찾을 수 없습니다."),
INVALID_REQUEST("ERR_GLOBAL_003", "유효하지 않은 요청입니다."),
TOO_MANY_REQUEST("ERR_GLOBAL_004", "너무 많은 요청을 보냈습니다."),
UNHANDLED_EXCEPTION("ERR_GLOBAL_999", "예상치 못한 오류가 발생했습니다."),

ID_TOKEN_UNSUPPORTED("ERR_OIDC_001", "지원되지 않는 ID Token 입니다."),
Expand Down
6 changes: 6 additions & 0 deletions src/main/resources/application-local.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ server:
min-spare: 10
max: 200

rate-limiter:
capacity: 5
refill-token-amount: 5
refill-duration-seconds: 10
bucket-ttl-seconds: 600

management:
endpoint:
health:
Expand Down
6 changes: 6 additions & 0 deletions src/main/resources/application-prod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,9 @@ management:
web:
exposure:
include: health, prometheus

rate-limiter:
capacity: ${RATE_LIMITER_CAPACITY}
refill-token-amount: ${RATE_LIMITER_REFILL_TOKEN_AMOUNT}
refill-duration-seconds: ${RATE_LIMITER_REFILL_DURATION_SECONDS}
bucket-ttl-seconds: ${RATE_LIMITER_BUCKET_TTL_SECONDS}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package com.sillim.recordit.config.filter;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.*;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.bucket4j.BucketConfiguration;
import io.github.bucket4j.ConsumptionProbe;
import io.github.bucket4j.distributed.BucketProxy;
import io.github.bucket4j.distributed.proxy.ProxyManager;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;

@ExtendWith(MockitoExtension.class)
class ApiThrottlingFilterTest {

@Mock ProxyManager<String> proxyManager;
@Mock BucketConfiguration bucketConfiguration;
@Mock ObjectMapper objectMapper;
@InjectMocks ApiThrottlingFilter apiThrottlingFilter;

@Test
@DisplayName("URL이 제한 API 리스트에 매칭되지 않으면 필터 체인이 동작한다.")
void calledFilterChainIfURLNotInAPILimitList() throws ServletException, IOException {
MockHttpServletRequest httpServletRequest =
new MockHttpServletRequest("POST", "/api/v1/invite");
httpServletRequest.addHeader("Authorization", "Bearer test1");
MockHttpServletResponse httpServletResponse = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
BucketProxy bucketProxy = mock(BucketProxy.class);
given(proxyManager.getProxy(anyString(), any())).willReturn(bucketProxy);

for (int i = 0; i < 6; i++) {
apiThrottlingFilter.doFilter(httpServletRequest, httpServletResponse, filterChain);
}

verify(filterChain, times(6)).doFilter(httpServletRequest, httpServletResponse);
}

@Test
@DisplayName("버킷 제한보다 적은 요청을 보내면 필터 체인이 동작한다.")
void calledFilterChainIfSendRequestLessThanBucketLimit() throws ServletException, IOException {
MockHttpServletRequest httpServletRequest =
new MockHttpServletRequest("POST", "/api/v1/invite/members/1");
httpServletRequest.addHeader("Authorization", "Bearer test2");
MockHttpServletResponse httpServletResponse = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
BucketProxy bucketProxy = mock(BucketProxy.class);
AtomicInteger counter = new AtomicInteger();
given(proxyManager.getProxy(anyString(), any())).willReturn(bucketProxy);
given(bucketProxy.tryConsumeAndReturnRemaining(eq(1L)))
.willAnswer(
invo -> {
int call = counter.incrementAndGet();
if (call >= 6) {
return ConsumptionProbe.rejected(0, 5_000_000_000L, 5_000_000_000L);
}
return ConsumptionProbe.consumed(1, 5_000_000_000L);
});

for (int i = 0; i < 4; i++) {
apiThrottlingFilter.doFilter(httpServletRequest, httpServletResponse, filterChain);
}

verify(filterChain, times(4)).doFilter(httpServletRequest, httpServletResponse);
}

@Test
@DisplayName("버킷 제한보다 많은 요청을 보내면 429 statusCode가 설정된다.")
void statusCodeSet429IfSendRequestMoreThanBucketLimit() throws ServletException, IOException {
MockHttpServletRequest httpServletRequest =
new MockHttpServletRequest("POST", "/api/v1/feeds/1");
httpServletRequest.addHeader("Authorization", "Bearer test3");
MockHttpServletResponse httpServletResponse = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
BucketProxy bucketProxy = mock(BucketProxy.class);
AtomicInteger counter = new AtomicInteger();
given(proxyManager.getProxy(anyString(), any())).willReturn(bucketProxy);
given(bucketProxy.tryConsumeAndReturnRemaining(eq(1L)))
.willAnswer(
invo -> {
int call = counter.incrementAndGet();
if (call >= 6) {
return ConsumptionProbe.rejected(0, 5_000_000_000L, 5_000_000_000L);
}
return ConsumptionProbe.consumed(1, 5_000_000_000L);
});
given(objectMapper.writeValueAsString(any())).willReturn("result");

for (int i = 0; i < 6; i++) {
apiThrottlingFilter.doFilter(httpServletRequest, httpServletResponse, filterChain);
}

verify(filterChain, times(5)).doFilter(httpServletRequest, httpServletResponse);
assertThat(httpServletResponse.getStatus()).isEqualTo(429);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sillim.recordit.config.filter.ApiThrottlingFilter;
import com.sillim.recordit.config.security.filter.AuthExceptionTranslationFilter;
import com.sillim.recordit.config.security.filter.JwtAuthenticationFilter;
import com.sillim.recordit.config.security.handler.AuthenticationExceptionHandler;
Expand Down Expand Up @@ -48,6 +49,7 @@ public abstract class RestDocsTest {
@MockBean AuthenticationExceptionHandler handler;
@MockBean AuthExceptionTranslationFilter exceptionTranslationFilter;
@MockBean JwtAuthenticationFilter jwtAuthenticationFilter;
@MockBean ApiThrottlingFilter apiThrottlingFilter;

protected MockMvc mockMvc;

Expand Down
Loading