diff --git a/src/main/java/uk/gov/hmcts/ccd/SecurityConfiguration.java b/src/main/java/uk/gov/hmcts/ccd/SecurityConfiguration.java index 1b57114c6d..fa720e72db 100644 --- a/src/main/java/uk/gov/hmcts/ccd/SecurityConfiguration.java +++ b/src/main/java/uk/gov/hmcts/ccd/SecurityConfiguration.java @@ -18,8 +18,10 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationConverter; import org.springframework.security.oauth2.server.resource.web.authentication.BearerTokenAuthenticationFilter; import org.springframework.security.web.SecurityFilterChain; +import uk.gov.hmcts.ccd.appinsights.AppInsights; import uk.gov.hmcts.ccd.customheaders.CustomHeadersFilter; import uk.gov.hmcts.ccd.data.SecurityUtils; +import uk.gov.hmcts.ccd.security.AppInsightsJwtDecoder; import uk.gov.hmcts.ccd.security.JwtGrantedAuthoritiesConverter; import uk.gov.hmcts.ccd.security.filters.ExceptionHandlingFilter; import uk.gov.hmcts.ccd.security.filters.SecurityLoggingFilter; @@ -113,7 +115,7 @@ public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Excepti } @Bean - JwtDecoder jwtDecoder() { + JwtDecoder jwtDecoder(AppInsights appInsights) { NimbusJwtDecoder jwtDecoder = (NimbusJwtDecoder)JwtDecoders.fromOidcIssuerLocation(issuerUri); // We are using issuerOverride instead of issuerUri as SIDAM has the wrong issuer at the moment @@ -124,7 +126,7 @@ JwtDecoder jwtDecoder() { OAuth2TokenValidator validator = new DelegatingOAuth2TokenValidator<>(withTimestamp); jwtDecoder.setJwtValidator(validator); - return jwtDecoder; + return new AppInsightsJwtDecoder(jwtDecoder, appInsights); } } diff --git a/src/main/java/uk/gov/hmcts/ccd/appinsights/AppInsights.java b/src/main/java/uk/gov/hmcts/ccd/appinsights/AppInsights.java index 8e579c6d14..fb6ff32020 100644 --- a/src/main/java/uk/gov/hmcts/ccd/appinsights/AppInsights.java +++ b/src/main/java/uk/gov/hmcts/ccd/appinsights/AppInsights.java @@ -75,6 +75,10 @@ public void trackEvent(String name, Map properties) { telemetry.trackEvent(name, properties, null); } + public void trackTrace(String message, Map customProperties, SeverityLevel severityLevel) { + telemetry.trackTrace(message, severityLevel, customProperties); + } + public void trackCallbackEvent( CallbackType callbackType, String url, String httpStatus, java.time.Duration duration) { Map properties = ImmutableMap.of( diff --git a/src/main/java/uk/gov/hmcts/ccd/security/AppInsightsJwtDecoder.java b/src/main/java/uk/gov/hmcts/ccd/security/AppInsightsJwtDecoder.java new file mode 100644 index 0000000000..e8cf4c5f24 --- /dev/null +++ b/src/main/java/uk/gov/hmcts/ccd/security/AppInsightsJwtDecoder.java @@ -0,0 +1,171 @@ +package uk.gov.hmcts.ccd.security; + +import com.microsoft.applicationinsights.telemetry.SeverityLevel; +import lombok.extern.slf4j.Slf4j; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtException; +import org.springframework.security.oauth2.jwt.JwtValidationException; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import uk.gov.hmcts.ccd.appinsights.AppInsights; + +import jakarta.servlet.http.HttpServletRequest; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +@Slf4j +public class AppInsightsJwtDecoder implements JwtDecoder { + + static final String JWT_VALIDATION_FAILURE_MESSAGE = "JWT validation failed"; + static final String FAILURE_TYPE = "failureType"; + static final String FAILURE_MESSAGE = "failureMessage"; + static final String METHOD = "method"; + static final String PATH = "path"; + static final String VALIDATION_ERRORS = "JWT validation errors"; + + private static final String NO_FAILURE_MESSAGE = "No failure message provided"; + private static final String UNKNOWN_REQUEST_VALUE = "UNKNOWN"; + + private final JwtDecoder jwtDecoder; + private final AppInsights appInsights; + + public AppInsightsJwtDecoder(JwtDecoder jwtDecoder, AppInsights appInsights) { + this.jwtDecoder = jwtDecoder; + this.appInsights = appInsights; + } + + @Override + public Jwt decode(String token) throws JwtException { + try { + return jwtDecoder.decode(token); + } catch (JwtException exception) { + logJwtValidationFailure(exception); + throw exception; + } + } + + private void logJwtValidationFailure(JwtException exception) { + String failureType = classifyJwtFailure(exception); + String failureMessage = sanitise(exception.getMessage()); + + log.warn("{}: {}", JWT_VALIDATION_FAILURE_MESSAGE, failureType); + appInsights.trackTrace( + JWT_VALIDATION_FAILURE_MESSAGE, + buildTelemetryProperties(exception, failureType, failureMessage), + SeverityLevel.Warning + ); + } + + private Map buildTelemetryProperties( + JwtException exception, + String failureType, + String failureMessage + ) { + Map properties = new HashMap<>(); + properties.put(FAILURE_TYPE, failureType); + properties.put(FAILURE_MESSAGE, failureMessage); + properties.put(METHOD, currentRequestMethod()); + properties.put(PATH, currentRequestPath()); + + if (exception instanceof JwtValidationException jwtValidationException) { + properties.put(VALIDATION_ERRORS, validationErrors(jwtValidationException)); + } + + return properties; + } + + private String currentRequestMethod() { + HttpServletRequest request = currentRequest(); + return requestValue(request == null ? null : request.getMethod()); + } + + private String currentRequestPath() { + HttpServletRequest request = currentRequest(); + return requestValue(request == null ? null : request.getRequestURI()); + } + + private HttpServletRequest currentRequest() { + if (RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes requestAttributes) { + return requestAttributes.getRequest(); + } + + return null; + } + + private String requestValue(String value) { + return value == null || value.isBlank() ? UNKNOWN_REQUEST_VALUE : value; + } + + private String validationErrors(JwtValidationException exception) { + return exception.getErrors() + .stream() + .map(this::errorDescription) + .collect(Collectors.joining("; ")); + } + + private String errorDescription(OAuth2Error error) { + if (error.getDescription() != null && !error.getDescription().isBlank()) { + return sanitise(error.getDescription()); + } + + return sanitise(error.getErrorCode()); + } + + private String sanitise(String message) { + if (message == null || message.isBlank()) { + return NO_FAILURE_MESSAGE; + } + + return message.replaceAll("\\s+", " "); + } + + private String classifyJwtFailure(Exception e) { + String msg = jwtFailureDetails(e).toLowerCase(Locale.ROOT); + + if (msg.isBlank()) { + return "UNKNOWN"; + } + + if (msg.contains("expired")) { + return "TOKEN_EXPIRED"; + } + if (msg.contains("signature")) { + return "INVALID_SIGNATURE"; + } + if (msg.contains("audience") || msg.contains("aud claim") || msg.contains("\"aud\"")) { + return "INVALID_AUDIENCE"; + } + if (msg.contains("issuer") || msg.contains("iss claim") || msg.contains("\"iss\"")) { + return "INVALID_ISSUER"; + } + + return "OTHER"; + } + + private String jwtFailureDetails(Exception exception) { + StringBuilder details = new StringBuilder(); + appendIfPresent(details, exception.getMessage()); + + if (exception instanceof JwtValidationException jwtValidationException) { + jwtValidationException.getErrors().forEach(error -> { + appendIfPresent(details, error.getDescription()); + appendIfPresent(details, error.getErrorCode()); + }); + } + + return details.toString(); + } + + private void appendIfPresent(StringBuilder details, String value) { + if (value != null && !value.isBlank()) { + if (details.length() > 0) { + details.append(' '); + } + details.append(value); + } + } +} diff --git a/src/test/java/uk/gov/hmcts/ccd/appinsights/AppInsightsTest.java b/src/test/java/uk/gov/hmcts/ccd/appinsights/AppInsightsTest.java index 63d9786255..6cb8422dfb 100644 --- a/src/test/java/uk/gov/hmcts/ccd/appinsights/AppInsightsTest.java +++ b/src/test/java/uk/gov/hmcts/ccd/appinsights/AppInsightsTest.java @@ -180,6 +180,21 @@ public void trackException_complex_shouldUseExceptionTelemetry_withCustomPropert assertThat(exceptionTelemetry.getException(), is(equalTo(testException))); } + @Test + public void trackTrace_shouldCallTrackTrace() { + + // ARRANGE + String message = "Test trace"; + Map customProperties = new HashMap<>(); + customProperties.put("test1", "Test property 1"); + + // ACT + classUnderTest.trackTrace(message, customProperties, SeverityLevel.Warning); + + // ASSERT + verify(telemetryClient, times(1)).trackTrace(message, SeverityLevel.Warning, customProperties); + } + @Test public void trackDependency_simple_shouldCallTrackDependency_successfulDependency() { diff --git a/src/test/java/uk/gov/hmcts/ccd/security/AppInsightsJwtDecoderTest.java b/src/test/java/uk/gov/hmcts/ccd/security/AppInsightsJwtDecoderTest.java new file mode 100644 index 0000000000..268b8b629d --- /dev/null +++ b/src/test/java/uk/gov/hmcts/ccd/security/AppInsightsJwtDecoderTest.java @@ -0,0 +1,202 @@ +package uk.gov.hmcts.ccd.security; + +import ch.qos.logback.classic.Level; +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.read.ListAppender; +import com.microsoft.applicationinsights.telemetry.SeverityLevel; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.slf4j.LoggerFactory; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.jwt.BadJwtException; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.JwtValidationException; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import uk.gov.hmcts.ccd.appinsights.AppInsights; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static uk.gov.hmcts.ccd.security.AppInsightsJwtDecoder.FAILURE_MESSAGE; +import static uk.gov.hmcts.ccd.security.AppInsightsJwtDecoder.FAILURE_TYPE; +import static uk.gov.hmcts.ccd.security.AppInsightsJwtDecoder.JWT_VALIDATION_FAILURE_MESSAGE; +import static uk.gov.hmcts.ccd.security.AppInsightsJwtDecoder.METHOD; +import static uk.gov.hmcts.ccd.security.AppInsightsJwtDecoder.PATH; +import static uk.gov.hmcts.ccd.security.AppInsightsJwtDecoder.VALIDATION_ERRORS; + +@ExtendWith(MockitoExtension.class) +class AppInsightsJwtDecoderTest { + + private static final String TOKEN = "jwt-token"; + private static final String UNKNOWN = "UNKNOWN"; + + @Mock + private JwtDecoder jwtDecoder; + + @Mock + private AppInsights appInsights; + + @Mock + private Jwt jwt; + + private AppInsightsJwtDecoder appInsightsJwtDecoder; + private Logger logger; + private ListAppender listAppender; + + @BeforeEach + void setUp() { + appInsightsJwtDecoder = new AppInsightsJwtDecoder(jwtDecoder, appInsights); + + logger = (Logger) LoggerFactory.getLogger(AppInsightsJwtDecoder.class); + listAppender = new ListAppender<>(); + listAppender.start(); + logger.addAppender(listAppender); + } + + @AfterEach + void tearDown() { + listAppender.stop(); + logger.detachAppender(listAppender); + RequestContextHolder.resetRequestAttributes(); + } + + @Test + void decodeShouldReturnJwtWhenDelegateDecodesToken() { + when(jwtDecoder.decode(TOKEN)).thenReturn(jwt); + + Jwt decodedJwt = appInsightsJwtDecoder.decode(TOKEN); + + assertThat(decodedJwt).isEqualTo(jwt); + verifyNoInteractions(appInsights); + } + + @Test + void decodeShouldLogJwtFailureToAppInsightsAndRethrowException() { + BadJwtException exception = new BadJwtException("Signed JWT rejected: Invalid signature"); + when(jwtDecoder.decode(TOKEN)).thenThrow(exception); + setCurrentRequest("GET", "/cases/123"); + + assertThatThrownBy(() -> appInsightsJwtDecoder.decode(TOKEN)).isSameAs(exception); + + Map properties = captureAppInsightsProperties(JWT_VALIDATION_FAILURE_MESSAGE); + + assertThat(properties.get(FAILURE_TYPE)).isEqualTo("INVALID_SIGNATURE"); + assertThat(properties.get(FAILURE_MESSAGE)).isEqualTo("Signed JWT rejected: Invalid signature"); + assertThat(properties.get(METHOD)).isEqualTo("GET"); + assertThat(properties.get(PATH)).isEqualTo("/cases/123"); + assertThat(properties).doesNotContainKey(VALIDATION_ERRORS); + + assertThat(listAppender.list).hasSize(1); + assertThat(listAppender.list.get(0).getLevel()).isEqualTo(Level.WARN); + assertThat(listAppender.list.get(0).getFormattedMessage()) + .contains("JWT validation failed: INVALID_SIGNATURE"); + } + + @Test + void decodeShouldIncludeValidationErrorDescriptionsInAppInsightsProperties() { + OAuth2Error expiredToken = new OAuth2Error("invalid_token", "Jwt expired at 2026-04-28T10:00:00Z", null); + OAuth2Error invalidClaim = new OAuth2Error("invalid_token", "The iss claim is not valid", null); + JwtValidationException exception = new JwtValidationException( + "Jwt validation failed", + List.of(expiredToken, invalidClaim) + ); + when(jwtDecoder.decode(TOKEN)).thenThrow(exception); + setCurrentRequest("POST", "/caseworkers/abc/jurisdictions"); + + assertThatThrownBy(() -> appInsightsJwtDecoder.decode(TOKEN)).isSameAs(exception); + + Map properties = captureAppInsightsProperties(JWT_VALIDATION_FAILURE_MESSAGE); + + assertThat(properties.get(FAILURE_TYPE)).isEqualTo("TOKEN_EXPIRED"); + assertThat(properties.get(FAILURE_MESSAGE)).isEqualTo("Jwt validation failed"); + assertThat(properties.get(METHOD)).isEqualTo("POST"); + assertThat(properties.get(PATH)).isEqualTo("/caseworkers/abc/jurisdictions"); + assertThat(properties.get(VALIDATION_ERRORS)) + .isEqualTo("Jwt expired at 2026-04-28T10:00:00Z; The iss claim is not valid"); + } + + @ParameterizedTest + @CsvSource({ + "The aud claim is not valid, INVALID_AUDIENCE", + "The iss claim is not valid, INVALID_ISSUER", + "Malformed JWT, OTHER" + }) + void decodeShouldClassifyJwtFailures(String failureMessage, String expectedFailureType) { + BadJwtException exception = new BadJwtException(failureMessage); + when(jwtDecoder.decode(TOKEN)).thenThrow(exception); + + assertThatThrownBy(() -> appInsightsJwtDecoder.decode(TOKEN)).isSameAs(exception); + + Map properties = captureAppInsightsProperties(JWT_VALIDATION_FAILURE_MESSAGE); + + assertThat(properties.get(FAILURE_TYPE)).isEqualTo(expectedFailureType); + assertThat(properties.get(FAILURE_MESSAGE)).isEqualTo(failureMessage); + assertThat(properties.get(METHOD)).isEqualTo(UNKNOWN); + assertThat(properties.get(PATH)).isEqualTo(UNKNOWN); + } + + @Test + void decodeShouldClassifyJwtFailureWithNoMessageAsUnknown() { + BadJwtException exception = new BadJwtException(null); + when(jwtDecoder.decode(TOKEN)).thenThrow(exception); + + assertThatThrownBy(() -> appInsightsJwtDecoder.decode(TOKEN)).isSameAs(exception); + + Map properties = captureAppInsightsProperties(JWT_VALIDATION_FAILURE_MESSAGE); + + assertThat(properties.get(FAILURE_TYPE)).isEqualTo(UNKNOWN); + assertThat(properties.get(FAILURE_MESSAGE)).isEqualTo("No failure message provided"); + } + + @Test + void decodeShouldUseValidationErrorCodeWhenDescriptionIsMissing() { + OAuth2Error validationError = new OAuth2Error("invalid_token", null, null); + JwtValidationException exception = new JwtValidationException( + "Jwt validation failed", + List.of(validationError) + ); + when(jwtDecoder.decode(TOKEN)).thenThrow(exception); + + assertThatThrownBy(() -> appInsightsJwtDecoder.decode(TOKEN)).isSameAs(exception); + + Map properties = captureAppInsightsProperties(JWT_VALIDATION_FAILURE_MESSAGE); + + assertThat(properties.get(FAILURE_TYPE)).isEqualTo("OTHER"); + assertThat(properties.get(VALIDATION_ERRORS)).isEqualTo("invalid_token"); + } + + @SuppressWarnings("unchecked") + private Map captureAppInsightsProperties(String expectedMessage) { + ArgumentCaptor> propertiesCaptor = ArgumentCaptor.forClass(Map.class); + + verify(appInsights).trackTrace( + eq(expectedMessage), + propertiesCaptor.capture(), + eq(SeverityLevel.Warning) + ); + + return propertiesCaptor.getValue(); + } + + private void setCurrentRequest(String method, String requestUri) { + MockHttpServletRequest request = new MockHttpServletRequest(method, requestUri); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request)); + } +}