diff --git a/src/main/java/org/commcare/formplayer/application/DebuggerController.java b/src/main/java/org/commcare/formplayer/application/DebuggerController.java index 8dde44172..28f352e85 100644 --- a/src/main/java/org/commcare/formplayer/application/DebuggerController.java +++ b/src/main/java/org/commcare/formplayer/application/DebuggerController.java @@ -63,8 +63,7 @@ public class DebuggerController extends AbstractBaseController { @UserRestore @ConfigureStorageFromSession public DebuggerFormattedQuestionsResponseBean getFormattedQuesitons( - @RequestBody SessionRequestBean debuggerRequest, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken) throws Exception { + @RequestBody SessionRequestBean debuggerRequest) throws Exception { SerializableFormSession serializableFormSession = formSessionService.getSessionById(debuggerRequest.getSessionId()); SerializableMenuSession serializableMenuSession = menuSessionService.getSessionById(serializableFormSession.getMenuSessionId()); FormSession formSession = formSessionFactory.getFormSession(serializableFormSession); @@ -92,7 +91,6 @@ public DebuggerFormattedQuestionsResponseBean getFormattedQuesitons( @AppInstall public MenuDebuggerContentResponseBean menuDebuggerContent( @RequestBody SessionNavigationBean debuggerMenuRequest, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken, HttpServletRequest request) throws Exception { MenuSession menuSession = menuSessionFactory.getMenuSessionFromBean(debuggerMenuRequest); @@ -114,7 +112,6 @@ public MenuDebuggerContentResponseBean menuDebuggerContent( @UserRestore @AppInstall public EvaluateXPathResponseBean menuEvaluateXpath(@RequestBody EvaluateXPathMenuRequestBean evaluateXPathRequestBean, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken, HttpServletRequest request) throws Exception { MenuSession menuSession = menuSessionFactory.getMenuSessionFromBean(evaluateXPathRequestBean); BaseResponseBean responseBean = runnerService.advanceSessionWithSelections( @@ -144,8 +141,8 @@ public EvaluateXPathResponseBean menuEvaluateXpath(@RequestBody EvaluateXPathMen @UserLock @UserRestore @ConfigureStorageFromSession - public EvaluateXPathResponseBean evaluateXpath(@RequestBody EvaluateXPathRequestBean evaluateXPathRequestBean, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken) throws Exception { + public EvaluateXPathResponseBean evaluateXpath(@RequestBody EvaluateXPathRequestBean evaluateXPathRequestBean) + throws Exception { SerializableFormSession serializableFormSession = formSessionService.getSessionById(evaluateXPathRequestBean.getSessionId()); FormSession formEntrySession = formSessionFactory.getFormSession(serializableFormSession); EvaluateXPathResponseBean evaluateXPathResponseBean = new EvaluateXPathResponseBean( diff --git a/src/main/java/org/commcare/formplayer/application/FormController.java b/src/main/java/org/commcare/formplayer/application/FormController.java index 1d32498d1..0384c58ed 100644 --- a/src/main/java/org/commcare/formplayer/application/FormController.java +++ b/src/main/java/org/commcare/formplayer/application/FormController.java @@ -122,8 +122,7 @@ public static HashMap validateAnswers(FormEntryController for @RequestMapping(value = Constants.URL_NEW_SESSION, method = RequestMethod.POST) @UserLock @UserRestore - public NewFormResponse newFormResponse(@RequestBody NewSessionRequestBean newSessionBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public NewFormResponse newFormResponse(@RequestBody NewSessionRequestBean newSessionBean) throws Exception { String postUrl = host + newSessionBean.getPostUrl(); return newFormResponseFactory.getResponse(newSessionBean, postUrl); @@ -133,8 +132,7 @@ public NewFormResponse newFormResponse(@RequestBody NewSessionRequestBean newSes @UserLock @UserRestore @ConfigureStorageFromSession - public FormEntryResponseBean changeLocale(@RequestBody ChangeLocaleRequestBean changeLocaleBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public FormEntryResponseBean changeLocale(@RequestBody ChangeLocaleRequestBean changeLocaleBean) throws Exception { SerializableFormSession serializableFormSession = formSessionService.getSessionById( changeLocaleBean.getSessionId()); @@ -150,8 +148,7 @@ public FormEntryResponseBean changeLocale(@RequestBody ChangeLocaleRequestBean c @UserLock @UserRestore @ConfigureStorageFromSession - public FormEntryResponseBean answerQuestion(@RequestBody AnswerQuestionRequestBean answerQuestionBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public FormEntryResponseBean answerQuestion(@RequestBody AnswerQuestionRequestBean answerQuestionBean) throws Exception { return saveAnswer(answerQuestionBean, null, false); } @@ -166,7 +163,6 @@ public FormEntryResponseBean answerQuestion(@RequestBody AnswerQuestionRequestBe @ConfigureStorageFromSession public FormEntryResponseBean answerMediaQuestion( @RequestPart(PART_ANSWER) AnswerQuestionRequestBean answerQuestionBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken, @RequestPart(PART_FILE) MultipartFile file) throws Exception { return saveAnswer(answerQuestionBean, file, false); @@ -176,9 +172,7 @@ public FormEntryResponseBean answerMediaQuestion( @UserLock @UserRestore @ConfigureStorageFromSession - public FormEntryResponseBean clearAnswer( - @RequestBody AnswerQuestionRequestBean answerQuestionBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public FormEntryResponseBean clearAnswer(@RequestBody AnswerQuestionRequestBean answerQuestionBean) throws Exception{ return saveAnswer(answerQuestionBean, null, true); } @@ -255,8 +249,8 @@ private FormEntryResponseBean saveAnswer(AnswerQuestionRequestBean answerQuestio @UserLock @UserRestore @ConfigureStorageFromSession - public FormEntryResponseBean newRepeat(@RequestBody RepeatRequestBean newRepeatRequestBean, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken) throws Exception { + public FormEntryResponseBean newRepeat(@RequestBody RepeatRequestBean newRepeatRequestBean) + throws Exception { SerializableFormSession serializableFormSession = formSessionService.getSessionById( newRepeatRequestBean.getSessionId()); FormSession formEntrySession = formSessionFactory.getFormSession(serializableFormSession); @@ -267,6 +261,7 @@ public FormEntryResponseBean newRepeat(@RequestBody RepeatRequestBean newRepeatR FormEntryResponseBean responseBean = mapper.readValue(response.toString(), FormEntryResponseBean.class); responseBean.setTitle(serializableFormSession.getTitle()); responseBean.setInstanceXml(new InstanceXmlBean(serializableFormSession.getInstanceXml())); + responseBean.setSessionId(serializableFormSession.getId()); log.info("New response: " + responseBean); return responseBean; } @@ -275,8 +270,7 @@ public FormEntryResponseBean newRepeat(@RequestBody RepeatRequestBean newRepeatR @ResponseBody @UserRestore @ConfigureStorageFromSession - public FormEntryResponseBean deleteRepeat(@RequestBody RepeatRequestBean deleteRepeatRequestBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public FormEntryResponseBean deleteRepeat(@RequestBody RepeatRequestBean deleteRepeatRequestBean) throws Exception { SerializableFormSession serializableFormSession = formSessionService.getSessionById( deleteRepeatRequestBean.getSessionId()); @@ -288,6 +282,7 @@ public FormEntryResponseBean deleteRepeat(@RequestBody RepeatRequestBean deleteR FormEntryResponseBean responseBean = mapper.readValue(response.toString(), FormEntryResponseBean.class); responseBean.setTitle(serializableFormSession.getTitle()); responseBean.setInstanceXml(new InstanceXmlBean(serializableFormSession.getInstanceXml())); + responseBean.setSessionId(serializableFormSession.getId()); return responseBean; } @@ -296,8 +291,7 @@ public FormEntryResponseBean deleteRepeat(@RequestBody RepeatRequestBean deleteR @UserLock @UserRestore @ConfigureStorageFromSession - public FormEntryNavigationResponseBean getNext(@RequestBody SessionRequestBean requestBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public FormEntryNavigationResponseBean getNext(@RequestBody SessionRequestBean requestBean) throws Exception { SerializableFormSession serializableFormSession = formSessionService.getSessionById( requestBean.getSessionId()); @@ -313,8 +307,7 @@ public FormEntryNavigationResponseBean getNext(@RequestBody SessionRequestBean r @UserLock @UserRestore @ConfigureStorageFromSession - public FormEntryNavigationResponseBean getNextSms(@RequestBody SessionRequestBean requestBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public FormEntryNavigationResponseBean getNextSms(@RequestBody SessionRequestBean requestBean) throws Exception { SerializableFormSession serializableFormSession = formSessionService.getSessionById( requestBean.getSessionId()); @@ -329,8 +322,7 @@ public FormEntryNavigationResponseBean getNextSms(@RequestBody SessionRequestBea @UserLock @UserRestore @ConfigureStorageFromSession - public FormEntryNavigationResponseBean getPrevious(@RequestBody SessionRequestBean requestBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public FormEntryNavigationResponseBean getPrevious(@RequestBody SessionRequestBean requestBean) throws Exception { SerializableFormSession serializableFormSession = formSessionService.getSessionById( requestBean.getSessionId()); @@ -350,8 +342,7 @@ public FormEntryNavigationResponseBean getPrevious(@RequestBody SessionRequestBe @UserLock @UserRestore @ConfigureStorageFromSession - public GetInstanceResponseBean getRawInstance(@RequestBody SessionRequestBean requestBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public GetInstanceResponseBean getRawInstance(@RequestBody SessionRequestBean requestBean) throws Exception { SerializableFormSession serializableFormSession = formSessionService.getSessionById( requestBean.getSessionId()); @@ -365,8 +356,7 @@ public GetInstanceResponseBean getRawInstance(@RequestBody SessionRequestBean re @UserLock @UserRestore @ConfigureStorageFromSession - public FormEntryNavigationResponseBean getCurrent(@RequestBody SessionRequestBean requestBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) + public FormEntryNavigationResponseBean getCurrent(@RequestBody SessionRequestBean requestBean) throws Exception { org.commcare.formplayer.objects.SerializableFormSession serializableFormSession = formSessionService.getSessionById(requestBean.getSessionId()); diff --git a/src/main/java/org/commcare/formplayer/application/FormSubmissionController.java b/src/main/java/org/commcare/formplayer/application/FormSubmissionController.java index 4d96906b3..e7835861a 100644 --- a/src/main/java/org/commcare/formplayer/application/FormSubmissionController.java +++ b/src/main/java/org/commcare/formplayer/application/FormSubmissionController.java @@ -34,7 +34,6 @@ public class FormSubmissionController extends AbstractBaseController { @UserRestore @ConfigureStorageFromSession public SubmitResponseBean submitForm(@RequestBody SubmitRequestBean submitRequestBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken, HttpServletRequest request) throws Exception { return formSubmissionHelper.processAndSubmitForm(request, submitRequestBean.getSessionId(), submitRequestBean.getDomain(), submitRequestBean.isPrevalidated(), submitRequestBean.getAnswers()); diff --git a/src/main/java/org/commcare/formplayer/application/IncompleteSessionController.java b/src/main/java/org/commcare/formplayer/application/IncompleteSessionController.java index 7cc5617f6..dfc45ab7a 100644 --- a/src/main/java/org/commcare/formplayer/application/IncompleteSessionController.java +++ b/src/main/java/org/commcare/formplayer/application/IncompleteSessionController.java @@ -37,8 +37,8 @@ public class IncompleteSessionController extends AbstractBaseController { @RequestMapping(value = Constants.URL_INCOMPLETE_SESSION, method = RequestMethod.POST) @UserLock @UserRestore - public NewFormResponse openIncompleteForm(@RequestBody SessionRequestBean incompleteSessionRequestBean, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken) throws Exception { + public NewFormResponse openIncompleteForm(@RequestBody SessionRequestBean incompleteSessionRequestBean) + throws Exception { SerializableFormSession session = formSessionService.getSessionById(incompleteSessionRequestBean.getSessionId()); storageFactory.configure(session); return newFormResponseFactory.getResponse(session, commCareSessionFactory.getCommCareSession(session.getMenuSessionId())); @@ -46,8 +46,7 @@ public NewFormResponse openIncompleteForm(@RequestBody SessionRequestBean incomp @RequestMapping(value = Constants.URL_GET_SESSIONS, method = RequestMethod.POST) @UserRestore - public GetSessionsResponse getSessions(@RequestBody FormsSessionsRequestBean getSessionRequest, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken) throws Exception { + public GetSessionsResponse getSessions(@RequestBody FormsSessionsRequestBean getSessionRequest) throws Exception { String scrubbedUsername = TableBuilder.scrubName(getSessionRequest.getUsername()); List formplayerSessions = formSessionService.getSessionsForUser(scrubbedUsername, getSessionRequest); @@ -67,8 +66,7 @@ public GetSessionsResponse getSessions(@RequestBody FormsSessionsRequestBean get @RequestMapping(value = Constants.URL_DELETE_INCOMPLETE_SESSION, method = RequestMethod.POST) public NotificationMessage deleteIncompleteForm( - @RequestBody SessionRequestBean incompleteSessionRequestBean, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken) throws Exception { + @RequestBody SessionRequestBean incompleteSessionRequestBean) throws Exception { deleteSession(incompleteSessionRequestBean.getSessionId()); return new NotificationMessage("Successfully deleted incomplete form.", false, NotificationMessage.Tag.incomplete_form); } diff --git a/src/main/java/org/commcare/formplayer/application/MenuController.java b/src/main/java/org/commcare/formplayer/application/MenuController.java index e638b38bc..10c3c846a 100644 --- a/src/main/java/org/commcare/formplayer/application/MenuController.java +++ b/src/main/java/org/commcare/formplayer/application/MenuController.java @@ -68,7 +68,6 @@ public class MenuController extends AbstractBaseController { @UserRestore @AppInstall public EntityDetailListResponse getDetails(@RequestBody SessionNavigationBean sessionNavigationBean, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken, HttpServletRequest request) throws Exception { MenuSession menuSession = menuSessionFactory.getMenuSessionFromBean(sessionNavigationBean); boolean isFuzzySearch = storageFactory.getPropertyManager().isFuzzySearchEnabled(); @@ -165,7 +164,6 @@ public EntityDetailListResponse getDetails(@RequestBody SessionNavigationBean se @UserRestore @AppInstall public BaseResponseBean navigateSessionWithAuth(@RequestBody SessionNavigationBean sessionNavigationBean, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken, HttpServletRequest request) throws Exception { String[] selections = sessionNavigationBean.getSelections(); MenuSession menuSession; @@ -230,7 +228,6 @@ private static T setLocationNeeds(T res @UserRestore @AppInstall public BaseResponseBean navigateToEndpoint(@RequestBody SessionNavigationBean sessionNavigationBean, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken, HttpServletRequest request) throws Exception { // Apps using aggressive syncs are likely to hit a sync whenever using endpoint-based navigation, // since they use it to jump between different sandboxes. Turn it off. diff --git a/src/main/java/org/commcare/formplayer/application/UtilController.java b/src/main/java/org/commcare/formplayer/application/UtilController.java index 501307fe3..a02272190 100644 --- a/src/main/java/org/commcare/formplayer/application/UtilController.java +++ b/src/main/java/org/commcare/formplayer/application/UtilController.java @@ -79,8 +79,7 @@ public class UtilController { @RequestMapping(value = Constants.URL_SYNC_DB, method = RequestMethod.POST) @UserLock @UserRestore - public SyncDbResponseBean syncUserDb(@RequestBody SyncDbRequestBean syncRequest, - @CookieValue(value = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) throws Exception { + public SyncDbResponseBean syncUserDb(@RequestBody SyncDbRequestBean syncRequest) throws Exception { restoreFactory.performTimedSync(); return new SyncDbResponseBean(); } @@ -90,7 +89,6 @@ public SyncDbResponseBean syncUserDb(@RequestBody SyncDbRequestBean syncRequest, @UserRestore @AppInstall public SyncDbResponseBean scheduleSync(@RequestBody SessionNavigationBean sessionNavigationBean, - @CookieValue(Constants.POSTGRES_DJANGO_SESSION_ID) String authToken, HttpServletRequest request) throws Exception { SyncDbResponseBean response = new SyncDbResponseBean(); if (restoreFactory.isRestoreXmlExpired()) { diff --git a/src/main/java/org/commcare/formplayer/aspects/UserRestoreAspect.java b/src/main/java/org/commcare/formplayer/aspects/UserRestoreAspect.java index a61c0e96e..bbf55f56a 100644 --- a/src/main/java/org/commcare/formplayer/aspects/UserRestoreAspect.java +++ b/src/main/java/org/commcare/formplayer/aspects/UserRestoreAspect.java @@ -3,8 +3,6 @@ import io.opentracing.Span; import io.opentracing.util.GlobalTracer; import io.sentry.Sentry; -import org.commcare.formplayer.auth.DjangoAuth; -import org.commcare.formplayer.auth.HqAuth; import org.commcare.formplayer.beans.AuthenticatedRequestBean; import org.commcare.formplayer.beans.SessionRequestBean; import org.commcare.formplayer.objects.SerializableFormSession; @@ -61,10 +59,7 @@ public void configureRestoreFactory(JoinPoint joinPoint) throws Throwable { String.format("Could not configure RestoreFactory with invalid request %s", Arrays.toString(args))); } AuthenticatedRequestBean requestBean = (AuthenticatedRequestBean) args[0]; - - HqAuth auth = getHqAuth((String) args[1]); - - configureRestoreFactory(requestBean, auth); + configureRestoreFactory(requestBean); configureSentryScope(restoreFactory); if (requestBean.isMustRestore()) { @@ -80,15 +75,15 @@ private void configureSentryScope(RestoreFactory restoreFactory) { }); } - private void configureRestoreFactory(AuthenticatedRequestBean requestBean, HqAuth auth) throws Exception { + private void configureRestoreFactory(AuthenticatedRequestBean requestBean) throws Exception { if (requestBean.getRestoreAsCaseId() != null) { // SMS user filling out a form as a case - restoreFactory.configure(requestBean.getDomain(), requestBean.getRestoreAsCaseId(), auth); + restoreFactory.configure(requestBean.getDomain(), requestBean.getRestoreAsCaseId()); return; } if (requestBean.getUsername() != null && requestBean.getDomain() != null) { // Normal restore path - restoreFactory.configure(requestBean, auth); + restoreFactory.configure(requestBean); final Span span = GlobalTracer.get().activeSpan(); if (span != null && (span instanceof MutableSpan)) { MutableSpan localRootSpan = ((MutableSpan) span).getLocalRootSpan(); @@ -101,9 +96,9 @@ private void configureRestoreFactory(AuthenticatedRequestBean requestBean, HqAut SerializableFormSession formSession = formSessionService.getSessionById(sessionId); if (formSession.getRestoreAsCaseId() != null) { - restoreFactory.configure(formSession.getDomain(), formSession.getRestoreAsCaseId(), auth); + restoreFactory.configure(formSession.getDomain(), formSession.getRestoreAsCaseId()); } else { - restoreFactory.configure(formSession.getUsername(), formSession.getDomain(), formSession.getAsUser(), auth); + restoreFactory.configure(formSession.getUsername(), formSession.getDomain(), formSession.getAsUser()); } } else { throw new Exception("Unable to configure restore factory"); @@ -114,13 +109,4 @@ private void configureRestoreFactory(AuthenticatedRequestBean requestBean, HqAut public void closeRestoreFactory(JoinPoint joinPoint) throws Throwable { restoreFactory.getSQLiteDB().closeConnection(); } - - private HqAuth getHqAuth(String sessionToken) { - if (sessionToken != null) { - return new DjangoAuth(sessionToken); - } - // Null auth expected for SMS requests - return null; - } - } diff --git a/src/main/java/org/commcare/formplayer/auth/DjangoAuth.java b/src/main/java/org/commcare/formplayer/auth/DjangoAuth.java deleted file mode 100644 index e3d0e14c0..000000000 --- a/src/main/java/org/commcare/formplayer/auth/DjangoAuth.java +++ /dev/null @@ -1,34 +0,0 @@ -package org.commcare.formplayer.auth; - -import org.commcare.formplayer.util.Constants; -import org.springframework.http.HttpHeaders; - -/** - * Class for storing a Django auth key and returning its respective headers - */ -public class DjangoAuth implements HqAuth { - - private final String authKey; - - public DjangoAuth(String authKey) { - this.authKey = authKey; - } - - - // We seem to need all of these headers at different times. TODO WSP figure that out - @Override - public HttpHeaders getAuthHeaders() { - return new HttpHeaders() { - { - add("Cookie", Constants.POSTGRES_DJANGO_SESSION_ID + "=" + authKey); - add(Constants.POSTGRES_DJANGO_SESSION_ID, authKey); - add("Authorization", Constants.POSTGRES_DJANGO_SESSION_ID + "=" + authKey); - } - }; - } - - @Override - public String toString() { - return "DjangoAuth key=" + authKey; - } -} diff --git a/src/main/java/org/commcare/formplayer/auth/HqAuth.java b/src/main/java/org/commcare/formplayer/auth/HqAuth.java deleted file mode 100644 index 1f81d7701..000000000 --- a/src/main/java/org/commcare/formplayer/auth/HqAuth.java +++ /dev/null @@ -1,10 +0,0 @@ -package org.commcare.formplayer.auth; - -import org.springframework.http.HttpHeaders; - -/** - * Created by willpride on 1/13/16. - */ -public interface HqAuth { - HttpHeaders getAuthHeaders(); -} diff --git a/src/main/java/org/commcare/formplayer/beans/auth/HqUserDetailsBean.java b/src/main/java/org/commcare/formplayer/beans/auth/HqUserDetailsBean.java index bee81b825..94f90bf7f 100644 --- a/src/main/java/org/commcare/formplayer/beans/auth/HqUserDetailsBean.java +++ b/src/main/java/org/commcare/formplayer/beans/auth/HqUserDetailsBean.java @@ -32,8 +32,9 @@ public class HqUserDetailsBean implements UserDetails { public HqUserDetailsBean() { } - public HqUserDetailsBean(String domain, String username) { + public HqUserDetailsBean(String sessionId, String domain, String username) { this(domain, new String[]{domain}, username, false, new String[]{}, new String[]{}); + this.authToken = sessionId; } public HqUserDetailsBean(String domain, String[] domains, String username, boolean isSuperuser, diff --git a/src/main/java/org/commcare/formplayer/hq/models/PostgresUser.java b/src/main/java/org/commcare/formplayer/hq/models/PostgresUser.java deleted file mode 100644 index 990d351c3..000000000 --- a/src/main/java/org/commcare/formplayer/hq/models/PostgresUser.java +++ /dev/null @@ -1,47 +0,0 @@ -package org.commcare.formplayer.hq.models; - - -/** - * Created by benrudolph on 9/7/16. - */ -public class PostgresUser { - - private String username; - private int userId; - private boolean isSuperuser; - private String authToken; - - public PostgresUser(int userId, String username, boolean isSuperuser) { - this.userId = userId; - this.username = username; - this.isSuperuser = isSuperuser; - } - - public PostgresUser(int userId, String username, boolean isSuperuser, String authToken) { - this.userId = userId; - this.username = username; - this.isSuperuser = isSuperuser; - this.authToken = authToken; - } - - public String getUsername() { - return username; - } - - public int getUserId() { - return userId; - } - - public boolean isSuperuser() { - return isSuperuser; - } - - public String getAuthToken() { - return authToken; - } - - public void setAuthToken(String authToken) { - this.authToken = authToken; - } -} - diff --git a/src/main/java/org/commcare/formplayer/services/RestoreFactory.java b/src/main/java/org/commcare/formplayer/services/RestoreFactory.java index ed8e9c7ec..9810cc6f6 100644 --- a/src/main/java/org/commcare/formplayer/services/RestoreFactory.java +++ b/src/main/java/org/commcare/formplayer/services/RestoreFactory.java @@ -1,10 +1,9 @@ package org.commcare.formplayer.services; -import static org.commcare.formplayer.util.Constants.TOGGLE_INCLUDE_STATE_HASH; - import com.google.common.collect.ImmutableMap; import com.timgroup.statsd.StatsDClient; - +import datadog.trace.api.Trace; +import io.sentry.SentryLevel; import org.apache.commons.io.IOUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -13,7 +12,6 @@ import org.commcare.core.parse.ParseUtils; import org.commcare.formplayer.DbUtils; import org.commcare.formplayer.api.process.FormRecordProcessorHelper; -import org.commcare.formplayer.auth.HqAuth; import org.commcare.formplayer.beans.AuthenticatedRequestBean; import org.commcare.formplayer.beans.auth.FeatureFlagChecker; import org.commcare.formplayer.engine.FormplayerTransactionParserFactory; @@ -26,14 +24,12 @@ import org.commcare.formplayer.sqlitedb.UserDB; import org.commcare.formplayer.util.Constants; import org.commcare.formplayer.util.FormplayerSentry; -import org.commcare.formplayer.util.RequestUtils; import org.commcare.formplayer.util.SimpleTimer; import org.commcare.formplayer.util.UserUtils; import org.commcare.formplayer.web.client.WebClient; import org.commcare.modern.database.TableBuilder; import org.javarosa.core.api.ClassNameHasher; import org.javarosa.core.model.User; -import org.javarosa.core.util.PropertyUtils; import org.javarosa.core.util.externalizable.PrototypeFactory; import org.javarosa.xml.util.InvalidStructureException; import org.javarosa.xml.util.UnfullfilledRequirementsException; @@ -56,27 +52,23 @@ import org.xml.sax.SAXException; import org.xmlpull.v1.XmlPullParserException; +import javax.annotation.PreDestroy; +import javax.annotation.Resource; +import javax.xml.parsers.DocumentBuilder; +import javax.xml.parsers.DocumentBuilderFactory; +import javax.xml.parsers.ParserConfigurationException; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.UnsupportedEncodingException; import java.net.URI; import java.sql.SQLException; -import java.time.Duration; -import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.TimeUnit; -import javax.annotation.PreDestroy; -import javax.annotation.Resource; -import javax.xml.parsers.DocumentBuilder; -import javax.xml.parsers.DocumentBuilderFactory; -import javax.xml.parsers.ParserConfigurationException; - -import datadog.trace.api.Trace; -import io.sentry.SentryLevel; +import static org.commcare.formplayer.util.Constants.TOGGLE_INCLUDE_STATE_HASH; /** @@ -96,8 +88,6 @@ public class RestoreFactory { private String username; private String scrubbedUsername; private String domain; - private HqAuth hqAuth; - private boolean permitAggressiveSyncs = true; public static final String FREQ_DAILY = "freq-daily"; @@ -110,8 +100,6 @@ public class RestoreFactory { private static final String DEVICE_ID_SLUG = "WebAppsLogin"; - private static final String ORIGIN_TOKEN_SLUG = "OriginToken"; - @Autowired protected StatsDClient datadogStatsDClient; @@ -127,27 +115,15 @@ public class RestoreFactory { @Autowired private WebClient webClient; - @Autowired - private RedisTemplate redisTemplateLong; - @Resource(name = "redisTemplateLong") private ValueOperations valueOperations; - @Autowired - private RedisTemplate redisTemplateString; - - @Resource(name = "redisTemplateString") - private ValueOperations originTokens; - @Autowired private RedisTemplate redisSetTemplate; @Resource(name = "redisSetTemplate") private SetOperations redisSessionCache; - @Value("${commcarehq.formplayerAuthKey}") - private String formplayerAuthKey; - private final Log log = LogFactory.getLog(RestoreFactory.class); CategoryTimingHelper.RecordingTimer downloadRestoreTimer; @@ -157,11 +133,10 @@ public class RestoreFactory { private String caseId; private boolean configured = false; - public void configure(String domain, String caseId, HqAuth auth) { + public void configure(String domain, String caseId) { this.setUsername(UserUtils.getRestoreAsCaseIdUsername(caseId)); this.setDomain(domain); this.setCaseId(caseId); - this.setHqAuth(auth); this.hasRestored = false; this.configured = true; sqLiteDB = new UserDB(domain, scrubbedUsername, null); @@ -169,11 +144,10 @@ public void configure(String domain, String caseId, HqAuth auth) { "username = %s, caseId = %s, domain = %s", username, caseId, domain)); } - public void configure(String username, String domain, String asUsername, HqAuth auth) { + public void configure(String username, String domain, String asUsername) { this.setUsername(username); this.setDomain(domain); this.setAsUsername(asUsername); - this.hqAuth = auth; this.hasRestored = false; this.configured = true; sqLiteDB = new UserDB(domain, scrubbedUsername, asUsername); @@ -181,11 +155,10 @@ public void configure(String username, String domain, String asUsername, HqAuth "username = %s, asUsername = %s, domain = %s", username, asUsername, domain)); } - public void configure(AuthenticatedRequestBean authenticatedRequestBean, HqAuth auth) { + public void configure(AuthenticatedRequestBean authenticatedRequestBean) { this.setUsername(authenticatedRequestBean.getUsername()); this.setDomain(authenticatedRequestBean.getDomain()); this.setAsUsername(authenticatedRequestBean.getRestoreAs()); - this.setHqAuth(auth); this.hasRestored = false; this.configured = true; sqLiteDB = new UserDB(domain, scrubbedUsername, asUsername); @@ -446,14 +419,7 @@ public InputStream getRestoreXml(boolean skipFixtures) { } public HttpHeaders getRequestHeaders(URI url) { - HttpHeaders headers; - if (RequestUtils.requestAuthedWithHmac()) { - headers = getHmacHeader(url); - } else { - headers = getHqAuth().getAuthHeaders();; - } - headers.addAll(getStandardHeaders()); - return headers; + return getStandardHeaders(); } private void recordSentryData(final String restoreUrl) { @@ -609,17 +575,9 @@ private HttpHeaders getStandardHeaders() { if (syncToken != null) { headers.set("X-CommCareHQ-LastSyncToken", getSyncToken()); } - headers.setAll(getOriginTokenHeader()); return headers; } - private Map getOriginTokenHeader() { - String originToken = PropertyUtils.genUUID(); - String redisKey = String.format("%s%s", ORIGIN_TOKEN_SLUG, originToken); - originTokens.set(redisKey, "valid", Duration.ofSeconds(60)); - return Collections.singletonMap("X-CommCareHQ-Origin-Token", originToken); - } - public URI getRestoreUrl(boolean skipFixtures) { // TODO: remove timing once the state hash rollout is complete return categoryTimingHelper.timed(Constants.TimingCategories.BUILD_RESTORE_URL, () -> { @@ -630,31 +588,6 @@ public URI getRestoreUrl(boolean skipFixtures) { }); } - private HttpHeaders getHmacHeader(URI url) { - // Do HMAC auth which requires only the path and query components of the URL - String requestPath = url.getRawPath(); - if (url.getRawQuery() != null) { - requestPath = String.format("%s?%s", requestPath, url.getRawQuery()); - } - if (!RequestUtils.requestAuthedWithHmac()) { - throw new RuntimeException(String.format("Tried getting HMAC Auth for request %s but this request" + - "was not validated with HMAC.", requestPath)); - } - String digest; - try { - digest = RequestUtils.getHmac(formplayerAuthKey, requestPath); - } catch (Exception e) { - log.error("Could not get HMAC signature to auth restore request", e); - throw new RuntimeException(e); - } - - return new HttpHeaders() { - { - add("X-MAC-DIGEST", digest); - } - }; - } - public URI getCaseRestoreUrl() { return UriComponentsBuilder.fromHttpUrl(caseRestoreUrl).buildAndExpand(domain, caseId).toUri(); } @@ -685,9 +618,6 @@ public URI getUserRestoreUrl(boolean skipFixtures) { asUserParam += "@" + domain + ".commcarehq.org"; } params.put("as", asUserParam); - } else if (getHqAuth() == null && username != null) { - // HQ requesting to force a sync for a user - params.put("as", username); } if (skipFixtures) { params.put("skip_fixtures", "true"); @@ -783,14 +713,6 @@ public void setDomain(String domain) { this.domain = domain; } - public HqAuth getHqAuth() { - return hqAuth; - } - - public void setHqAuth(HqAuth hqAuth) { - this.hqAuth = hqAuth; - } - public String getAsUsername() { return asUsername; } diff --git a/src/main/java/org/commcare/formplayer/session/FormSession.java b/src/main/java/org/commcare/formplayer/session/FormSession.java index 09586eaa8..1c88d68db 100644 --- a/src/main/java/org/commcare/formplayer/session/FormSession.java +++ b/src/main/java/org/commcare/formplayer/session/FormSession.java @@ -572,6 +572,7 @@ public FormEntryResponseBean answerQuestionToJson(Object answer, String answerIn FormEntryResponseBean.class); if (!session.isInPromptMode() || !Constants.ANSWER_RESPONSE_STATUS_POSITIVE.equals( response.getStatus())) { + response.setSessionId(session.getId()); return response; } return getNextFormNavigation(); @@ -588,6 +589,7 @@ public FormEntryNavigationResponseBean getFormNavigation() throws IOException { responseBean.setTitle(session.getTitle()); responseBean.setCurrentIndex(session.getCurrentIndex()); responseBean.setEvent(responseBean.getTree()[0]); + responseBean.setSessionId(session.getId()); return responseBean; } @@ -611,6 +613,7 @@ public FormEntryNavigationResponseBean getNextFormNavigation() throws IOExceptio responseBean.setInstanceXml(null); responseBean.setTree(null); responseBean.setStatus(Constants.ANSWER_RESPONSE_STATUS_POSITIVE); + responseBean.setSessionId(session.getId()); if (nextEvent == FormEntryController.EVENT_END_OF_FORM) { String output = submitGetXml(); responseBean.getEvent().setOutput(output); diff --git a/src/main/java/org/commcare/formplayer/util/Constants.java b/src/main/java/org/commcare/formplayer/util/Constants.java index ad491a114..7add5263a 100644 --- a/src/main/java/org/commcare/formplayer/util/Constants.java +++ b/src/main/java/org/commcare/formplayer/util/Constants.java @@ -90,10 +90,6 @@ public class Constants { public static final String CCZ_LATEST_SAVED = "save"; // Postgres tables - public static final String POSTGRES_TOKEN_TABLE_NAME = "django_session"; - // Token table generated from django rest framework - public static final String POSTGRES_AUTH_TOKEN_TABLE_NAME = "authtoken_token"; - public static final String POSTGRES_USER_TABLE_NAME = "auth_user"; public static final String POSTGRES_MENU_SESSION_TABLE_NAME = "menu_sessions"; public static final String POSTGRES_VIRTUAL_DATA_INSTANCE_TABLE_NAME = "virtual_data_instance"; @@ -101,9 +97,6 @@ public class Constants { public static final String SESSION_DETAILS_VIEW = "/hq/admin/session_details/"; - // Couch databases - public static final String COUCH_USERS_DB = "__users"; - public static final String POSTGRES_DJANGO_SESSION_ID = "sessionid"; public static final String COMMCARE_USER_SUFFIX = "commcarehq.org"; diff --git a/src/main/java/org/commcare/formplayer/util/RequestUtils.java b/src/main/java/org/commcare/formplayer/util/RequestUtils.java index 115d39cbd..8e6ad8052 100644 --- a/src/main/java/org/commcare/formplayer/util/RequestUtils.java +++ b/src/main/java/org/commcare/formplayer/util/RequestUtils.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import java.util.Optional; import javax.crypto.Mac; @@ -80,15 +81,19 @@ public static String getRequestEndpoint() { return request == null ? "unknown" : StringUtils.strip(request.getRequestURI(), "/"); } + public static String getHmac(String key, String data) throws Exception { + return getHmac(key, data.getBytes(StandardCharsets.UTF_8)); + } + /** * Get the HMAC hash of a given request body with a given key Used by Formplayer to validate * requests from HQ using shared internal key `commcarehq.formplayerAuthKey` */ - public static String getHmac(String key, String data) throws Exception { + public static String getHmac(String key, byte[] data) throws Exception { Mac sha256_HMAC = Mac.getInstance("HmacSHA256"); SecretKeySpec secret_key = new SecretKeySpec(key.getBytes("UTF-8"), "HmacSHA256"); sha256_HMAC.init(secret_key); - return Base64.encodeBase64String(sha256_HMAC.doFinal(data.getBytes("UTF-8"))); + return Base64.encodeBase64String(sha256_HMAC.doFinal(data)); } public static HttpServletRequest getCurrentRequest() { diff --git a/src/main/java/org/commcare/formplayer/web/client/CommCareAuthInterceptor.java b/src/main/java/org/commcare/formplayer/web/client/CommCareAuthInterceptor.java new file mode 100644 index 000000000..035a0a1fb --- /dev/null +++ b/src/main/java/org/commcare/formplayer/web/client/CommCareAuthInterceptor.java @@ -0,0 +1,32 @@ +package org.commcare.formplayer.web.client; + +import org.springframework.http.HttpRequest; +import org.springframework.http.client.ClientHttpRequestExecution; +import org.springframework.http.client.ClientHttpRequestInterceptor; +import org.springframework.http.client.ClientHttpResponse; + +import java.io.IOException; +import java.net.URISyntaxException; + +/** + * Rest request interceptor that will modify requests to CommCare with the appropriate authentication details. + */ +public abstract class CommCareAuthInterceptor implements ClientHttpRequestInterceptor { + + private final CommCareRequestFilter requestFilter; + + public CommCareAuthInterceptor(CommCareRequestFilter requestFilter) { + this.requestFilter = requestFilter; + } + + @Override + public ClientHttpResponse intercept(HttpRequest request, byte[] body, + ClientHttpRequestExecution execution) throws IOException { + if (requestFilter.isMatch(request)) { + request = modifyRequest(request, body); + } + return execution.execute(request, body); + } + + protected abstract HttpRequest modifyRequest(HttpRequest request, byte[] body); +} diff --git a/src/main/java/org/commcare/formplayer/web/client/CommCareDefaultHeaders.java b/src/main/java/org/commcare/formplayer/web/client/CommCareDefaultHeaders.java new file mode 100644 index 000000000..c6affb92b --- /dev/null +++ b/src/main/java/org/commcare/formplayer/web/client/CommCareDefaultHeaders.java @@ -0,0 +1,56 @@ +package org.commcare.formplayer.web.client; + +import org.commcare.formplayer.util.RequestUtils; +import org.javarosa.core.util.PropertyUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.web.client.RestTemplateRequestCustomizer; +import org.springframework.data.redis.core.ValueOperations; +import org.springframework.http.client.ClientHttpRequest; +import org.springframework.stereotype.Component; + +import javax.annotation.Resource; +import java.time.Duration; + + +/** + * Adds default headers to requests to CommCareHQ + */ +@Component +public class CommCareDefaultHeaders implements RestTemplateRequestCustomizer { + + private static final String ORIGIN_TOKEN_SLUG = "OriginToken"; + private final CommCareRequestFilter requestFilter; + + private ValueOperations originTokens; + + @Autowired + public CommCareDefaultHeaders(@Value("${commcarehq.host}") String commcareHost) { + requestFilter = new CommCareRequestFilter(commcareHost); + } + + @Resource(name = "redisTemplateString") + public void setOriginTokens(ValueOperations originTokens) { + this.originTokens = originTokens; + } + + @Override + public void customize(ClientHttpRequest request) { + if (!requestFilter.isMatch(request)) { + return; + } + + String ipAddress = RequestUtils.getIpAddress(); + if (ipAddress != null) { + request.getHeaders().add("X-CommCareHQ-Origin-IP", ipAddress); + } + request.getHeaders().add("X-CommCareHQ-Origin-Token", getOriginTokenHeader()); + } + + private String getOriginTokenHeader() { + String originToken = PropertyUtils.genUUID(); + String redisKey = String.format("%s%s", ORIGIN_TOKEN_SLUG, originToken); + originTokens.set(redisKey, "valid", Duration.ofSeconds(60)); + return originToken; + } +} diff --git a/src/main/java/org/commcare/formplayer/web/client/CommCareHmacRequestFilter.java b/src/main/java/org/commcare/formplayer/web/client/CommCareHmacRequestFilter.java new file mode 100644 index 000000000..366fb5224 --- /dev/null +++ b/src/main/java/org/commcare/formplayer/web/client/CommCareHmacRequestFilter.java @@ -0,0 +1,22 @@ +package org.commcare.formplayer.web.client; + +import org.commcare.formplayer.util.RequestUtils; +import org.springframework.http.HttpRequest; + +/** + * Filter to determine if a request is a request to CommCare. Additionally, filter based on + * whether the current Spring request was authenticated with HMAC or not. + */ +public class CommCareHmacRequestFilter extends CommCareRequestFilter { + + private final boolean matchHmac; + + public CommCareHmacRequestFilter(String commcareHost, boolean matchHmac) { + super(commcareHost); + this.matchHmac = matchHmac; + } + + public boolean isMatch(HttpRequest request) { + return super.isMatch(request) && matchHmac == RequestUtils.requestAuthedWithHmac(); + } +} diff --git a/src/main/java/org/commcare/formplayer/web/client/CommCareRequestFilter.java b/src/main/java/org/commcare/formplayer/web/client/CommCareRequestFilter.java new file mode 100644 index 000000000..fe6a24dc5 --- /dev/null +++ b/src/main/java/org/commcare/formplayer/web/client/CommCareRequestFilter.java @@ -0,0 +1,22 @@ +package org.commcare.formplayer.web.client; + +import org.commcare.formplayer.util.RequestUtils; +import org.springframework.http.HttpRequest; + +import java.util.function.Predicate; + +/** + * Filter to determine if a request is a request to CommCare. + */ +public class CommCareRequestFilter { + + private final String commcareHost; + + public CommCareRequestFilter(String commcareHost) { + this.commcareHost = commcareHost; + } + + public boolean isMatch(HttpRequest request) { + return request.getURI().toString().startsWith(commcareHost); + } +} diff --git a/src/main/java/org/commcare/formplayer/web/client/HmacAuthInterceptor.java b/src/main/java/org/commcare/formplayer/web/client/HmacAuthInterceptor.java new file mode 100644 index 000000000..54412506e --- /dev/null +++ b/src/main/java/org/commcare/formplayer/web/client/HmacAuthInterceptor.java @@ -0,0 +1,88 @@ +package org.commcare.formplayer.web.client; + +import lombok.extern.apachecommons.CommonsLog; +import okhttp3.HttpUrl; +import org.commcare.formplayer.beans.auth.HqUserDetailsBean; +import org.commcare.formplayer.util.Constants; +import org.commcare.formplayer.util.RequestUtils; +import org.jetbrains.annotations.NotNull; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRequest; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.*; + +/** + * Rest request interceptor that will add HMAC headers to requests that require it + */ +@CommonsLog +public class HmacAuthInterceptor extends CommCareAuthInterceptor { + + private final String formplayerAuthKey; + + public HmacAuthInterceptor(CommCareRequestFilter requestFilter, String formplayerAuthKey) { + super(requestFilter); + this.formplayerAuthKey = formplayerAuthKey; + } + + @Override + protected HttpRequest modifyRequest(HttpRequest request, byte[] body) { + request = addAsParamIfNotPresent(request); + HttpHeaders hmacHeaders = getHmacHeaders(request, body); + request.getHeaders().addAll(hmacHeaders); + return request; + } + + /** + * HMAC requests require the 'as' parameter to be present in the request URI + * since it is not possible to determine the user from the authentication header. + */ + @NotNull + private static HttpRequest addAsParamIfNotPresent(HttpRequest request) { + URI uri = request.getURI(); + + HttpUrl httpUrl = HttpUrl.get(uri); + boolean asParamMissing = httpUrl != null && !httpUrl.queryParameterNames().contains("as"); + Optional userDetails = RequestUtils.getUserDetails(); + if (asParamMissing && userDetails.isPresent()) { + String asParamValue = userDetails.get().getUsername(); + URI newUri = httpUrl.newBuilder().addQueryParameter("as", asParamValue).build().uri(); + request = new ReplaceUriHttpRequest(newUri, request); + log.warn(String.format("HMAC request augmented with 'as=%s' param", asParamValue)); + } + return request; + } + + private HttpHeaders getHmacHeaders(HttpRequest request, byte[] body) { + try { + return switch (Objects.requireNonNull(request.getMethod())) { + case GET -> getHmacHeaderForGetRequest(request.getURI()); + case POST -> getHmacHeader(body); + default -> throw new RuntimeException("Unsupported HTTP method: " + request.getMethod()); + }; + } catch (Exception e) { + log.error("Could not get HMAC signature", e); + throw new RuntimeException(e); + } + } + + private HttpHeaders getHmacHeaderForGetRequest(URI url) throws Exception { + // Do HMAC auth which requires only the path and query components of the URL + String requestPath = url.getRawPath(); + if (url.getRawQuery() != null) { + requestPath = String.format("%s?%s", requestPath, url.getRawQuery()); + } + + return getHmacHeader(requestPath.getBytes(StandardCharsets.UTF_8)); + } + + private HttpHeaders getHmacHeader(byte[] data) throws Exception { + String digest = RequestUtils.getHmac(formplayerAuthKey, data); + return new HttpHeaders() { + { + add(Constants.HMAC_HEADER, digest); + } + }; + } +} diff --git a/src/main/java/org/commcare/formplayer/web/client/RestTemplateConfig.java b/src/main/java/org/commcare/formplayer/web/client/RestTemplateConfig.java index 6dbcf2d65..a2c3eb0f1 100644 --- a/src/main/java/org/commcare/formplayer/web/client/RestTemplateConfig.java +++ b/src/main/java/org/commcare/formplayer/web/client/RestTemplateConfig.java @@ -1,6 +1,7 @@ package org.commcare.formplayer.web.client; import org.commcare.formplayer.util.Constants; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.context.annotation.Bean; @@ -19,22 +20,26 @@ public class RestTemplateConfig { public static String MODE_REPLACE_HOST = "replace-host"; + private CommCareDefaultHeaders commCareDefaultHeaders; + @Value("${formplayer.externalRequestMode}") private String externalRequestMode; @Value("${commcarehq.host}") private String commcareHost; + @Value("${commcarehq.formplayerAuthKey}") + private String formplayerAuthKey; + public RestTemplateConfig() { } /** * Constructor for tests - * - * @param externalRequestMode */ - public RestTemplateConfig(String commcareHost, String externalRequestMode) { + public RestTemplateConfig(String commcareHost, String formplayerAuthKey, String externalRequestMode) { this.commcareHost = commcareHost; + this.formplayerAuthKey = formplayerAuthKey; this.externalRequestMode = externalRequestMode; } @@ -43,13 +48,25 @@ public RestTemplate restTemplate(RestTemplateBuilder builder) throws URISyntaxEx builder = builder .setConnectTimeout(Duration.ofMillis(Constants.CONNECT_TIMEOUT)) .setReadTimeout(Duration.ofMillis(Constants.READ_TIMEOUT)) - .requestFactory(OkHttp3ClientHttpRequestFactory.class); + .requestFactory(OkHttp3ClientHttpRequestFactory.class) + .additionalRequestCustomizers(commCareDefaultHeaders); if (externalRequestMode.equals(MODE_REPLACE_HOST)) { log.warn(String.format("RestTemplate configured in '%s' mode", externalRequestMode)); builder = builder.additionalInterceptors( new RewriteHostRequestInterceptor(commcareHost)); } - return builder.build(); + + CommCareRequestFilter hmacAuthFilter = new CommCareHmacRequestFilter(commcareHost, true); + CommCareRequestFilter sessionAuthFilter = new CommCareHmacRequestFilter(commcareHost, false); + return builder.additionalInterceptors( + new HmacAuthInterceptor(hmacAuthFilter, formplayerAuthKey), + new SessionAuthInterceptor(sessionAuthFilter) + ).build(); + } + + @Autowired + public void setCommCareDefaultHeaders(CommCareDefaultHeaders commCareDefaultHeaders) { + this.commCareDefaultHeaders = commCareDefaultHeaders; } } diff --git a/src/main/java/org/commcare/formplayer/web/client/SessionAuthInterceptor.java b/src/main/java/org/commcare/formplayer/web/client/SessionAuthInterceptor.java new file mode 100644 index 000000000..7263c94b3 --- /dev/null +++ b/src/main/java/org/commcare/formplayer/web/client/SessionAuthInterceptor.java @@ -0,0 +1,38 @@ +package org.commcare.formplayer.web.client; + +import lombok.extern.apachecommons.CommonsLog; +import org.commcare.formplayer.util.Constants; +import org.commcare.formplayer.util.RequestUtils; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpRequest; + +/** + * Rest request interceptor that will add Django session auth headers to requests that require it + */ +@CommonsLog +public class SessionAuthInterceptor extends CommCareAuthInterceptor { + + public SessionAuthInterceptor(CommCareRequestFilter requestFilter) { + super(requestFilter); + } + + @Override + protected HttpRequest modifyRequest(HttpRequest request, byte[] body) { + HttpHeaders sessionHeaders = getSessionHeaders(); + request.getHeaders().addAll(sessionHeaders); + return request; + } + + public HttpHeaders getSessionHeaders() { + HttpHeaders headers = new HttpHeaders(); + RequestUtils.getUserDetails().ifPresent(userDetails -> { + String authToken = userDetails.getAuthToken(); + String auth = Constants.POSTGRES_DJANGO_SESSION_ID + "=" + authToken; + headers.add("Cookie", auth); + headers.add(Constants.POSTGRES_DJANGO_SESSION_ID, authToken); + headers.add("Authorization", auth); + }); + return headers; + } + +} diff --git a/src/main/java/org/commcare/formplayer/web/client/WebClient.java b/src/main/java/org/commcare/formplayer/web/client/WebClient.java index 09df99cb3..54dabb1e1 100644 --- a/src/main/java/org/commcare/formplayer/web/client/WebClient.java +++ b/src/main/java/org/commcare/formplayer/web/client/WebClient.java @@ -1,23 +1,16 @@ package org.commcare.formplayer.web.client; import com.google.common.collect.Multimap; - +import lombok.extern.apachecommons.CommonsLog; import org.commcare.formplayer.services.RestoreFactory; -import org.commcare.formplayer.util.RequestUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.RequestEntity; -import org.springframework.http.ResponseEntity; +import org.springframework.http.*; import org.springframework.stereotype.Component; import org.springframework.util.LinkedMultiValueMap; import org.springframework.web.client.RestTemplate; import java.net.URI; -import lombok.extern.apachecommons.CommonsLog; - @Component @CommonsLog public class WebClient { @@ -58,21 +51,15 @@ public String post(String url, T body) { } public String post(String url, T body, boolean isMultipart) { - checkHmac(); URI uri = URI.create(url); HttpHeaders headers = restoreFactory.getRequestHeaders(uri); if (isMultipart) { headers.setContentType(MediaType.MULTIPART_FORM_DATA); } - String ipAddress = RequestUtils.getIpAddress(); - if (ipAddress != null) { - headers.add("X-CommCareHQ-Origin-IP", ipAddress); - } return postRaw(uri, headers, body, String.class).getBody(); } public String postFormData(String url, Multimap data) { - checkHmac(); URI uri = URI.create(url); LinkedMultiValueMap postData = new LinkedMultiValueMap<>(); data.forEach(postData::add); @@ -83,7 +70,6 @@ public String postFormData(String url, Multimap data) { } public Boolean caseClaimPost(String url, T body) { - checkHmac(); URI uri = URI.create(url); ResponseEntity entity = postRaw(uri, restoreFactory.getRequestHeaders(uri), body, String.class); Boolean shouldSync = true; @@ -109,16 +95,6 @@ public ResponseEntity postRaw(URI uri, HttpHeaders headers, T body, Cl return response; } - /** - * This is not a technical limitation, just a code limitation that should be fixed in the - * future. - */ - private void checkHmac() { - if (RequestUtils.requestAuthedWithHmac()) { - throw new RuntimeException("HMAC auth not supported for POST requests"); - } - } - @Autowired public void setRestoreFactory(RestoreFactory restoreFactory) { this.restoreFactory = restoreFactory; diff --git a/src/test/java/org/commcare/formplayer/auth/MockMultipartController.java b/src/test/java/org/commcare/formplayer/auth/MockMultipartController.java index b5544ae7a..5c21fe889 100644 --- a/src/test/java/org/commcare/formplayer/auth/MockMultipartController.java +++ b/src/test/java/org/commcare/formplayer/auth/MockMultipartController.java @@ -34,8 +34,7 @@ public class MockMultipartController { @ConfigureStorageFromSession public void multipartRequest( @RequestPart(PART_FILE) MultipartFile file, - @RequestPart(PART_ANSWER) SessionRequestBean sessionRequestBean, - @CookieValue(name = Constants.POSTGRES_DJANGO_SESSION_ID, required = false) String authToken) { + @RequestPart(PART_ANSWER) SessionRequestBean sessionRequestBean) { return; } } diff --git a/src/test/java/org/commcare/formplayer/auth/SessionAuthTests.java b/src/test/java/org/commcare/formplayer/auth/SessionAuthTests.java index 5c9a5a25e..657dc1422 100644 --- a/src/test/java/org/commcare/formplayer/auth/SessionAuthTests.java +++ b/src/test/java/org/commcare/formplayer/auth/SessionAuthTests.java @@ -111,7 +111,7 @@ public void testMultipartEndpointWithFullAuth_WithAnyHmacAuth_Succeeds() throws private void mockValidAuth(String sessionId) { TokenMatcher matcher = new TokenMatcher(DOMAIN, USERNAME, sessionId); when(userDetailsService.loadUserDetails(argThat(matcher))).thenReturn( - new HqUserDetailsBean(DOMAIN, USERNAME) + new HqUserDetailsBean(sessionId, DOMAIN, USERNAME) ); } diff --git a/src/test/java/org/commcare/formplayer/junit/Installer.kt b/src/test/java/org/commcare/formplayer/junit/Installer.kt index 2870d50f1..1c197b35f 100644 --- a/src/test/java/org/commcare/formplayer/junit/Installer.kt +++ b/src/test/java/org/commcare/formplayer/junit/Installer.kt @@ -5,7 +5,6 @@ import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.node.ObjectNode import org.apache.commons.logging.LogFactory import org.assertj.core.api.Assertions -import org.commcare.formplayer.auth.DjangoAuth import org.commcare.formplayer.beans.InstallRequestBean import org.commcare.formplayer.beans.menus.CommandListResponseBean import org.commcare.formplayer.services.FormplayerStorageFactory @@ -50,7 +49,7 @@ class Installer( ) val bean = refAndBean.second storageFactory.configure(bean) - restoreFactory.configure(bean, DjangoAuth("key")) + restoreFactory.configure(bean) if (bean.isMustRestore) { restoreFactory.performTimedSync() } diff --git a/src/test/java/org/commcare/formplayer/junit/RestoreFactoryExtension.kt b/src/test/java/org/commcare/formplayer/junit/RestoreFactoryExtension.kt index f07599431..7f882e81c 100644 --- a/src/test/java/org/commcare/formplayer/junit/RestoreFactoryExtension.kt +++ b/src/test/java/org/commcare/formplayer/junit/RestoreFactoryExtension.kt @@ -1,9 +1,7 @@ package org.commcare.formplayer.junit -import org.commcare.formplayer.auth.DjangoAuth import org.commcare.formplayer.services.RestoreFactory import org.junit.jupiter.api.extension.AfterEachCallback -import org.junit.jupiter.api.extension.BeforeAllCallback import org.junit.jupiter.api.extension.BeforeEachCallback import org.junit.jupiter.api.extension.ExtensionContext import org.mockito.ArgumentMatchers @@ -61,7 +59,7 @@ class RestoreFactoryExtension( override fun beforeEach(context: ExtensionContext) { restoreFactory = SpringExtension.getApplicationContext(context).getBean(RestoreFactory::class.java) reset() - restoreFactory.configure(username, domain, asUser, DjangoAuth("test")) + restoreFactory.configure(username, domain, asUser) configureMock() } diff --git a/src/test/java/org/commcare/formplayer/junit/request/AnswerMediaQuestionRequest.kt b/src/test/java/org/commcare/formplayer/junit/request/AnswerMediaQuestionRequest.kt index 608887b31..cb1b30d40 100644 --- a/src/test/java/org/commcare/formplayer/junit/request/AnswerMediaQuestionRequest.kt +++ b/src/test/java/org/commcare/formplayer/junit/request/AnswerMediaQuestionRequest.kt @@ -49,7 +49,6 @@ class AnswerMediaQuestionRequest( multipart("/" + Constants.URL_ANSWER_MEDIA_QUESTION) .file(file) .part(answer) - .cookie(Cookie(Constants.POSTGRES_DJANGO_SESSION_ID, "derp")) .with(SecurityMockMvcRequestPostProcessors.csrf()) .with(SecurityMockMvcRequestPostProcessors.user("user")) ).andExpect(MockMvcResultMatchers.status().isOk) diff --git a/src/test/java/org/commcare/formplayer/junit/request/MockRequest.kt b/src/test/java/org/commcare/formplayer/junit/request/MockRequest.kt index a67c3af49..19fe771b7 100644 --- a/src/test/java/org/commcare/formplayer/junit/request/MockRequest.kt +++ b/src/test/java/org/commcare/formplayer/junit/request/MockRequest.kt @@ -36,9 +36,8 @@ open class MockRequest( open fun getRequestBuilder(requestPath: String, requestBean: B): MockHttpServletRequestBuilder { return post(requestPath) .contentType(MediaType.APPLICATION_JSON) - .cookie(Cookie(Constants.POSTGRES_DJANGO_SESSION_ID, "derp")) .with(SecurityMockMvcRequestPostProcessors.csrf()) - .with(SecurityMockMvcRequestPostProcessors.user(HqUserDetailsBean("domain", "user"))) + .with(SecurityMockMvcRequestPostProcessors.user(HqUserDetailsBean("derp", "domain", "user"))) .content(mapper.writeValueAsString(requestBean)) } } diff --git a/src/test/java/org/commcare/formplayer/tests/BaseTestClass.java b/src/test/java/org/commcare/formplayer/tests/BaseTestClass.java index d74f3b1d1..2406f91f2 100644 --- a/src/test/java/org/commcare/formplayer/tests/BaseTestClass.java +++ b/src/test/java/org/commcare/formplayer/tests/BaseTestClass.java @@ -26,7 +26,6 @@ import org.commcare.formplayer.application.MenuController; import org.commcare.formplayer.application.SQLiteProperties; import org.commcare.formplayer.application.UtilController; -import org.commcare.formplayer.auth.DjangoAuth; import org.commcare.formplayer.beans.AnswerQuestionRequestBean; import org.commcare.formplayer.beans.AuthenticatedRequestBean; import org.commcare.formplayer.beans.ChangeLocaleRequestBean; @@ -545,7 +544,7 @@ NewFormResponse startNewForm(String requestPath, String formPath) throws Excepti String requestPayload = FileUtils.getFile(this.getClass(), requestPath); NewSessionRequestBean newSessionRequestBean = mapper.readValue(requestPayload, NewSessionRequestBean.class); - restoreFactoryMock.configure(newSessionRequestBean, new DjangoAuth("derp")); + restoreFactoryMock.configure(newSessionRequestBean); return new NewFormRequest(mockFormController, webClientMock, formPath) .requestWithBean(newSessionRequestBean) .bean(); @@ -559,7 +558,7 @@ SubmitResponseBean submitForm(String requestPath, String sessionId) throws Excep SubmitRequestBean submitRequestBean = mapper.readValue (FileUtils.getFile(this.getClass(), requestPath), SubmitRequestBean.class); submitRequestBean.setSessionId(sessionId); - restoreFactoryMock.configure(submitRequestBean, new DjangoAuth("123")); + restoreFactoryMock.configure(submitRequestBean); return new SubmitFormRequest(mockFormSubmissionController) .requestWithBean(submitRequestBean).bean(); } @@ -575,7 +574,7 @@ SubmitResponseBean submitForm(Map answers, String sessionId, sessionId); submitRequestBean.setAnswers(answers); submitRequestBean.setPrevalidated(prevalidated); - restoreFactoryMock.configure(submitRequestBean, new DjangoAuth("123")); + restoreFactoryMock.configure(submitRequestBean); return new SubmitFormRequest(mockFormSubmissionController) .requestWithBean(submitRequestBean).bean(); } @@ -585,7 +584,7 @@ protected SyncDbResponseBean syncDb() { syncDbRequestBean.setDomain(restoreFactoryMock.getDomain()); syncDbRequestBean.setUsername(restoreFactoryMock.getUsername()); syncDbRequestBean.setRestoreAs(restoreFactoryMock.getAsUsername()); - restoreFactoryMock.configure(syncDbRequestBean, new DjangoAuth("derp")); + restoreFactoryMock.configure(syncDbRequestBean); return new SyncDbRequest(mockUtilController, restoreFactoryMock).requestWithBean(syncDbRequestBean).bean(); } @@ -962,7 +961,7 @@ private T generateMockQuery(ControllerType controllerType, if (bean instanceof AuthenticatedRequestBean) { restoreFactoryMock.getSQLiteDB().closeConnection(); - restoreFactoryMock.configure((AuthenticatedRequestBean)bean, new DjangoAuth("derp")); + restoreFactoryMock.configure((AuthenticatedRequestBean)bean); } if (bean instanceof InstallRequestBean) { @@ -1000,7 +999,6 @@ private T generateMockQuery(ControllerType controllerType, result = controller.perform( post(urlPrepend(urlPath)) .contentType(MediaType.APPLICATION_JSON) - .cookie(new Cookie(Constants.POSTGRES_DJANGO_SESSION_ID, "derp")) .content((String)bean)); break; @@ -1008,7 +1006,6 @@ private T generateMockQuery(ControllerType controllerType, result = controller.perform( get(urlPrepend(urlPath)) .contentType(MediaType.APPLICATION_JSON) - .cookie(new Cookie(Constants.POSTGRES_DJANGO_SESSION_ID, "derp")) .content((String)bean)); break; } diff --git a/src/test/java/org/commcare/formplayer/tests/CaseClaimTests.java b/src/test/java/org/commcare/formplayer/tests/CaseClaimTests.java index 3a4f387b3..29cbcfc8a 100644 --- a/src/test/java/org/commcare/formplayer/tests/CaseClaimTests.java +++ b/src/test/java/org/commcare/formplayer/tests/CaseClaimTests.java @@ -1,30 +1,10 @@ package org.commcare.formplayer.tests; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - import com.google.common.collect.Multimap; - import org.commcare.cases.model.Case; import org.commcare.formplayer.beans.NewFormResponse; import org.commcare.formplayer.beans.SubmitResponseBean; -import org.commcare.formplayer.beans.menus.CommandListResponseBean; -import org.commcare.formplayer.beans.menus.EntityDetailListResponse; -import org.commcare.formplayer.beans.menus.EntityDetailResponse; -import org.commcare.formplayer.beans.menus.EntityListResponse; -import org.commcare.formplayer.beans.menus.QueryResponseBean; +import org.commcare.formplayer.beans.menus.*; import org.commcare.formplayer.junit.RestoreFactoryAnswer; import org.commcare.formplayer.objects.QueryData; import org.commcare.formplayer.sandbox.SqlStorage; @@ -45,10 +25,13 @@ import org.springframework.cache.CacheManager; import org.springframework.cache.caffeine.CaffeineCache; +import javax.annotation.Nullable; import java.util.HashMap; import java.util.Hashtable; -import javax.annotation.Nullable; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; /** * Regression tests for fixed behaviors @@ -751,13 +734,14 @@ public void testSplitScreenResponse() throws Exception { EntityListResponse.class); assertNull(responseBean.getQueryResponse(), "Query response attached to entity response when split screen is disabled"); - WithHqUserSecurityContextFactory.setSecurityContext( + try (AutoCloseable __ = WithHqUserSecurityContextFactory.setSecurityContext( HqUserDetails.builder().enabledToggles(new String[]{"SPLIT_SCREEN_CASE_SEARCH"}).build() - ); - responseBean = sessionNavigateWithQuery(new String[]{"1", "action 1"}, - "caseclaim", - null, - EntityListResponse.class); + )) { + responseBean = sessionNavigateWithQuery(new String[]{"1", "action 1"}, + "caseclaim", + null, + EntityListResponse.class); + } assertNotNull(responseBean.getQueryResponse(), "No query response attached to entity response when split screen is enabled"); } diff --git a/src/test/java/org/commcare/formplayer/tests/CsrfIntegrationTest.java b/src/test/java/org/commcare/formplayer/tests/CsrfIntegrationTest.java index 9c63326a4..63d35c106 100644 --- a/src/test/java/org/commcare/formplayer/tests/CsrfIntegrationTest.java +++ b/src/test/java/org/commcare/formplayer/tests/CsrfIntegrationTest.java @@ -113,7 +113,6 @@ public void postApiCall_withoutCsrf_fails() throws Exception { mockUtilController.perform( post("/" + Constants.URL_DELETE_APPLICATION_DBS) .contentType(MediaType.APPLICATION_JSON) - .cookie(new Cookie(Constants.POSTGRES_DJANGO_SESSION_ID, "derp")) .content(payload) ).andExpect(status().isForbidden()); } @@ -123,7 +122,6 @@ public void getApiCall_withoutCsrf_succeeds() throws Exception { mockUtilController.perform( get("/" + Constants.URL_SERVER_UP) .contentType(MediaType.APPLICATION_JSON) - .cookie(new Cookie(Constants.POSTGRES_DJANGO_SESSION_ID, "derp")) ).andExpect(status().isOk()); } @@ -133,7 +131,6 @@ public void postApiCall_withHmacHeader_withoutCsrf_succeeds() throws Exception { HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); - headers.add("Cookie", Constants.POSTGRES_DJANGO_SESSION_ID + "=" + "derp"); headers.add(Constants.HMAC_HEADER, "BHOwo3mPXbtWM91RO0g5HQOt+DtiiQVnCWMFsvjkWVc="); HttpEntity entity = new HttpEntity<>(payload, headers); diff --git a/src/test/java/org/commcare/formplayer/tests/FormplayerDatadogTests.java b/src/test/java/org/commcare/formplayer/tests/FormplayerDatadogTests.java index d985a0bab..deb342398 100644 --- a/src/test/java/org/commcare/formplayer/tests/FormplayerDatadogTests.java +++ b/src/test/java/org/commcare/formplayer/tests/FormplayerDatadogTests.java @@ -1,11 +1,6 @@ package org.commcare.formplayer.tests; -import static org.commcare.formplayer.util.Constants.TOGGLE_DETAILED_TAGGING; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; - import com.timgroup.statsd.StatsDClient; - import org.commcare.formplayer.util.FormplayerDatadog; import org.commcare.formplayer.utils.HqUserDetails; import org.commcare.formplayer.utils.WithHqUserSecurityContextFactory; @@ -18,6 +13,10 @@ import java.util.Collections; import java.util.List; +import static org.commcare.formplayer.util.Constants.TOGGLE_DETAILED_TAGGING; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + public class FormplayerDatadogTests { FormplayerDatadog datadog; StatsDClient mockDatadogClient; @@ -103,15 +102,16 @@ public void testTransientTagUsedForRecordExecutionTime() { } @Test - public void testAddRequestScopedDetailedTagForEligibleDomain() { - enableDetailTagging(); - - // detailed_tag was added to FormplayerDatadog in beforeTest - datadog.addRequestScopedTag("detailed_tag", "test_value"); - datadog.recordExecutionTime("requests", 100, Collections.emptyList()); - String expectedTag = "detailed_tag:test_value"; - String[] args = {expectedTag}; - verify(mockDatadogClient).recordExecutionTime("requests", 100, args); + public void testAddRequestScopedDetailedTagForEligibleDomain() throws Exception { + try (AutoCloseable __ = enableDetailTagging()) { + + // detailed_tag was added to FormplayerDatadog in beforeTest + datadog.addRequestScopedTag("detailed_tag", "test_value"); + datadog.recordExecutionTime("requests", 100, Collections.emptyList()); + String expectedTag = "detailed_tag:test_value"; + String[] args = {expectedTag}; + verify(mockDatadogClient).recordExecutionTime("requests", 100, args); + } } @Test @@ -125,16 +125,16 @@ public void testAddRequestScopedDetailedTagForIneligibleDomain() { } @Test - public void testAddTransientDetailedTagForEligibleDomain() { - enableDetailTagging(); - - List transientTags = new ArrayList<>(); - // detailed_tag was added to FormplayerDatadog in beforeTest - transientTags.add(new FormplayerDatadog.Tag("detailed_tag", "test_value")); - datadog.recordExecutionTime("requests", 100, transientTags); - String expectedTag = "detailed_tag:test_value"; - String[] args = {expectedTag}; - verify(mockDatadogClient).recordExecutionTime("requests", 100, args); + public void testAddTransientDetailedTagForEligibleDomain() throws Exception { + try (AutoCloseable __ = enableDetailTagging()) { + List transientTags = new ArrayList<>(); + // detailed_tag was added to FormplayerDatadog in beforeTest + transientTags.add(new FormplayerDatadog.Tag("detailed_tag", "test_value")); + datadog.recordExecutionTime("requests", 100, transientTags); + String expectedTag = "detailed_tag:test_value"; + String[] args = {expectedTag}; + verify(mockDatadogClient).recordExecutionTime("requests", 100, args); + } } @Test @@ -148,8 +148,8 @@ public void testAddTransientDetailedTagForIneligibleDomain() { verify(mockDatadogClient).recordExecutionTime("requests", 100, args); } - private void enableDetailTagging() { - WithHqUserSecurityContextFactory.setSecurityContext( + private AutoCloseable enableDetailTagging() { + return WithHqUserSecurityContextFactory.setSecurityContext( HqUserDetails.builder().enabledToggles(new String[]{TOGGLE_DETAILED_TAGGING}).build() ); } diff --git a/src/test/java/org/commcare/formplayer/tests/HqUserDetailsServiceTests.java b/src/test/java/org/commcare/formplayer/tests/HqUserDetailsServiceTests.java index 5d0582323..89f6e5dee 100644 --- a/src/test/java/org/commcare/formplayer/tests/HqUserDetailsServiceTests.java +++ b/src/test/java/org/commcare/formplayer/tests/HqUserDetailsServiceTests.java @@ -84,7 +84,7 @@ public void whenCallingGetUserDetails_thenClientMakesCorrectCall() "\"domains\":[\"domain\"]," + "\"djangoUserId\":1," + "\"username\":\"user@domain.commcarehq.org\"," + - "\"authToken\":\"authToke\"," + + "\"authToken\":\"authToken\"," + "\"superUser\":false" + "}"; @@ -98,6 +98,7 @@ public void whenCallingGetUserDetails_thenClientMakesCorrectCall() assertThat(details.getUsername()).isEqualTo("user@domain.commcarehq.org"); assertThat(details.getDomains()).isEqualTo(new String[]{"domain"}); + assertThat(details.getAuthToken()).isEqualTo("authToken"); } @Test diff --git a/src/test/java/org/commcare/formplayer/tests/HqUserDetailsTests.java b/src/test/java/org/commcare/formplayer/tests/HqUserDetailsTests.java index dd086490d..f14f0facc 100644 --- a/src/test/java/org/commcare/formplayer/tests/HqUserDetailsTests.java +++ b/src/test/java/org/commcare/formplayer/tests/HqUserDetailsTests.java @@ -9,6 +9,8 @@ import org.junit.jupiter.api.Test; import org.springframework.security.core.context.SecurityContextHolder; +import java.io.Closeable; + public class HqUserDetailsTests { @Test @@ -38,23 +40,25 @@ public void testCommCareUserIsAuthorized() { } @Test - public void testFeatureFlagChecker_isToggleEnabled() { - WithHqUserSecurityContextFactory.setSecurityContext( + public void testFeatureFlagChecker_isToggleEnabled() throws Exception { + try (AutoCloseable __ = WithHqUserSecurityContextFactory.setSecurityContext( HqUserDetails.builder().enabledToggles(new String[]{"toggle_a", "toggle_b"}).build() - ); - Assertions.assertTrue(FeatureFlagChecker.isToggleEnabled("toggle_a")); - Assertions.assertTrue(FeatureFlagChecker.isToggleEnabled("toggle_b")); - Assertions.assertFalse(FeatureFlagChecker.isToggleEnabled("toggle_c")); + )) { + Assertions.assertTrue(FeatureFlagChecker.isToggleEnabled("toggle_a")); + Assertions.assertTrue(FeatureFlagChecker.isToggleEnabled("toggle_b")); + Assertions.assertFalse(FeatureFlagChecker.isToggleEnabled("toggle_c")); + } } @Test - public void testFeatureFlagChecker_isPreviewEnabled() { - WithHqUserSecurityContextFactory.setSecurityContext( + public void testFeatureFlagChecker_isPreviewEnabled() throws Exception { + try (AutoCloseable __ = WithHqUserSecurityContextFactory.setSecurityContext( HqUserDetails.builder().enabledPreviews(new String[]{"preview_a", "preview_b"}).build() - ); - Assertions.assertTrue(FeatureFlagChecker.isPreviewEnabled("preview_a")); - Assertions.assertTrue(FeatureFlagChecker.isPreviewEnabled("preview_b")); - Assertions.assertFalse(FeatureFlagChecker.isPreviewEnabled("preview_c")); + )) { + Assertions.assertTrue(FeatureFlagChecker.isPreviewEnabled("preview_a")); + Assertions.assertTrue(FeatureFlagChecker.isPreviewEnabled("preview_b")); + Assertions.assertFalse(FeatureFlagChecker.isPreviewEnabled("preview_c")); + } } @AfterEach diff --git a/src/test/java/org/commcare/formplayer/tests/RestoreFactoryTest.java b/src/test/java/org/commcare/formplayer/tests/RestoreFactoryTest.java index 1ebe1a21f..ac21229b9 100644 --- a/src/test/java/org/commcare/formplayer/tests/RestoreFactoryTest.java +++ b/src/test/java/org/commcare/formplayer/tests/RestoreFactoryTest.java @@ -1,22 +1,9 @@ package org.commcare.formplayer.tests; -import static org.commcare.formplayer.util.Constants.TOGGLE_INCLUDE_STATE_HASH; -import static org.hamcrest.Matchers.hasEntry; -import static org.hamcrest.core.IsEqual.equalTo; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.when; - -import static java.util.Collections.singletonList; - import org.commcare.cases.util.CaseDBUtils; -import org.commcare.formplayer.auth.DjangoAuth; import org.commcare.formplayer.beans.AuthenticatedRequestBean; import org.commcare.formplayer.configuration.CacheConfiguration; import org.commcare.formplayer.services.RestoreFactory; -import org.commcare.formplayer.util.Constants; -import org.commcare.formplayer.util.RequestUtils; import org.commcare.formplayer.utils.TestContext; import org.commcare.formplayer.utils.WithHqUser; import org.hamcrest.Description; @@ -26,69 +13,50 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.Mock; +import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.MockedStatic; import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest; import org.springframework.http.HttpHeaders; import org.springframework.test.context.ContextConfiguration; -import org.springframework.web.context.request.RequestContextHolder; -import org.springframework.web.context.request.ServletRequestAttributes; -import java.net.URI; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.UUID; -import javax.servlet.http.HttpServletRequest; +import static java.util.Collections.singletonList; +import static org.commcare.formplayer.util.Constants.TOGGLE_INCLUDE_STATE_HASH; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; /** * Created by benrudolph on 1/19/17. */ @WebMvcTest @ContextConfiguration(classes = {TestContext.class, CacheConfiguration.class}) +@ExtendWith(MockitoExtension.class) public class RestoreFactoryTest { private static final String BASE_URL = "http://localhost:8000/a/restore-domain/phone/restore/"; - private String username = "restore-dude"; - private String domain = "restore-domain"; - private String asUsername = "restore-gal"; - - @Value("${commcarehq.formplayerAuthKey}") - private String formplayerAuthKey; @Autowired RestoreFactory restoreFactorySpy; - @Mock - private ServletRequestAttributes requestAttributes; - - @Mock - private HttpServletRequest request; - @BeforeEach public void setUp() throws Exception { - MockitoAnnotations.openMocks(this); Mockito.reset(restoreFactorySpy); AuthenticatedRequestBean requestBean = new AuthenticatedRequestBean(); - requestBean.setRestoreAs(asUsername); - requestBean.setUsername(username); - requestBean.setDomain(domain); - restoreFactorySpy.configure(requestBean, new DjangoAuth("key")); + requestBean.setRestoreAs("restore-gal"); + requestBean.setUsername("restore-dude"); + requestBean.setDomain("restore-domain"); + restoreFactorySpy.configure(requestBean); restoreFactorySpy.setAsUsername(null); restoreFactorySpy.setCaseId(null); - - // mock request - RequestContextHolder.setRequestAttributes(requestAttributes); - when(requestAttributes.getRequest()).thenReturn(request); - } - - private void mockHmacRequest() { - when(request.getAttribute(eq(Constants.HMAC_REQUEST_ATTRIBUTE))).thenReturn(true); } private void mockSyncFreq(String freq) { @@ -248,66 +216,11 @@ public void testGetRequestHeaders() { String syncToken = "synctoken"; Mockito.doReturn(syncToken).when(restoreFactorySpy).getSyncToken(); HttpHeaders headers = restoreFactorySpy.getRequestHeaders(null); - assertEquals(7, headers.size()); - validateHeaders(headers, Arrays.asList( - hasEntry("Cookie", singletonList("sessionid=key")), - hasEntry("sessionid", singletonList("key")), - hasEntry("Authorization", singletonList("sessionid=key")), - hasEntry("X-OpenRosa-Version", singletonList("3.0")), - hasEntry("X-OpenRosa-DeviceId", singletonList("WebAppsLogin")), - hasEntry("X-CommCareHQ-LastSyncToken", singletonList(syncToken)), - hasEntry(equalTo("X-CommCareHQ-Origin-Token"), new ValueIsUUID())) - ); - } - - @Test - public void testGetRequestHeaders_HmacAuth() throws Exception { - mockHmacRequest(); - restoreFactorySpy.configure(domain, "case_id", null); - String requestPath = "/a/restore-domain/phone/case_restore/case_id_123/"; - HttpHeaders headers = restoreFactorySpy.getRequestHeaders( - new URI("http://localhost:8000" + requestPath)); - assertEquals(4, headers.size()); - validateHeaders(headers, Arrays.asList( - hasEntry("X-MAC-DIGEST", - singletonList(RequestUtils.getHmac(formplayerAuthKey, requestPath))), - hasEntry("X-OpenRosa-Version", singletonList("3.0")), - hasEntry("X-OpenRosa-DeviceId", singletonList("WebAppsLogin")), - hasEntry(equalTo("X-CommCareHQ-Origin-Token"), new ValueIsUUID())) - ); - } - - @Test - public void testGetRequestHeaders_HmacAuth_UrlWithQuery() throws Exception { - mockHmacRequest(); - restoreFactorySpy.configure(domain, "case_id", null); - String requestPath = - "/a/restore-domain/phone/case_restore/case_id_123/?query_param=true"; - HttpHeaders headers = restoreFactorySpy.getRequestHeaders( - new URI("http://localhost:8000" + requestPath)); - assertEquals(4, headers.size()); - validateHeaders(headers, Arrays.asList( - hasEntry("X-MAC-DIGEST", - singletonList(RequestUtils.getHmac(formplayerAuthKey, requestPath))), - hasEntry("X-OpenRosa-Version", singletonList("3.0")), - hasEntry("X-OpenRosa-DeviceId", singletonList("WebAppsLogin")), - hasEntry(equalTo("X-CommCareHQ-Origin-Token"), new ValueIsUUID())) - ); - } - - @Test - public void testGetRequestHeaders_UseHmacAuthEvenIfHqAuthPresent() throws Exception { - mockHmacRequest(); - String requestPath = "/a/restore-domain/phone/case_restore/case_id_123/"; - HttpHeaders headers = restoreFactorySpy.getRequestHeaders( - new URI("http://localhost:8000" + requestPath)); - assertEquals(4, headers.size()); + assertEquals(3, headers.size()); validateHeaders(headers, Arrays.asList( - hasEntry("X-MAC-DIGEST", - singletonList(RequestUtils.getHmac(formplayerAuthKey, requestPath))), hasEntry("X-OpenRosa-Version", singletonList("3.0")), hasEntry("X-OpenRosa-DeviceId", singletonList("WebAppsLogin")), - hasEntry(equalTo("X-CommCareHQ-Origin-Token"), new ValueIsUUID())) + hasEntry("X-CommCareHQ-LastSyncToken", singletonList(syncToken))) ); } diff --git a/src/test/java/org/commcare/formplayer/utils/HqUserDetails.java b/src/test/java/org/commcare/formplayer/utils/HqUserDetails.java index 204abe815..452019caa 100644 --- a/src/test/java/org/commcare/formplayer/utils/HqUserDetails.java +++ b/src/test/java/org/commcare/formplayer/utils/HqUserDetails.java @@ -30,6 +30,7 @@ public class HqUserDetails { private boolean isSuperUser; private String[] enabledPreviews; private String[] enabledToggles; + private String authToken; public HqUserDetails(WithHqUser withUser) { String username = StringUtils.hasLength(withUser.username()) ? withUser.username() @@ -42,9 +43,12 @@ public HqUserDetails(WithHqUser withUser) { this.isSuperUser = withUser.isSuperUser(); this.enabledPreviews = withUser.enabledPreviews(); this.enabledToggles = withUser.enabledToggles(); + this.authToken = withUser.authToken(); } public HqUserDetailsBean toBean() { - return new HqUserDetailsBean(domain, domains, username, isSuperUser, enabledToggles, enabledPreviews); + HqUserDetailsBean bean = new HqUserDetailsBean(domain, domains, username, isSuperUser, enabledToggles, enabledPreviews); + bean.setAuthToken(this.authToken); + return bean; } } diff --git a/src/test/java/org/commcare/formplayer/utils/MockRestTemplateBuilder.java b/src/test/java/org/commcare/formplayer/utils/MockRestTemplateBuilder.java new file mode 100644 index 000000000..6d1babdef --- /dev/null +++ b/src/test/java/org/commcare/formplayer/utils/MockRestTemplateBuilder.java @@ -0,0 +1,43 @@ +package org.commcare.formplayer.utils; + +import org.commcare.formplayer.web.client.CommCareDefaultHeaders; +import org.commcare.formplayer.web.client.RestTemplateConfig; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.springframework.boot.web.client.RestTemplateBuilder; +import org.springframework.data.redis.core.ValueOperations; +import org.springframework.web.client.RestTemplate; + +import java.net.URISyntaxException; + +public class MockRestTemplateBuilder { + + private String commcareHost = ""; + private String formpayerAuthKey = ""; + private String externalRequestMode = ""; + + private ValueOperations originTokens = Mockito.mock(ValueOperations.class); + + public MockRestTemplateBuilder withCommcareHost(String commcareHost) { + this.commcareHost = commcareHost; + return this; + } + + public MockRestTemplateBuilder withFormpayerAuthKey(String formpayerAuthKey) { + this.formpayerAuthKey = formpayerAuthKey; + return this; + } + + public MockRestTemplateBuilder withExternalRequestMode(String externalRequestMode) { + this.externalRequestMode = externalRequestMode; + return this; + } + + public RestTemplate getRestTemplate() throws URISyntaxException { + RestTemplateConfig config = new RestTemplateConfig(commcareHost, formpayerAuthKey, externalRequestMode); + CommCareDefaultHeaders commCareDefaultHeaders = new CommCareDefaultHeaders(commcareHost); + commCareDefaultHeaders.setOriginTokens(originTokens); + config.setCommCareDefaultHeaders(commCareDefaultHeaders); + return config.restTemplate(new RestTemplateBuilder()); + } +} diff --git a/src/test/java/org/commcare/formplayer/utils/WithHqUser.java b/src/test/java/org/commcare/formplayer/utils/WithHqUser.java index 1b3a4ff1b..c9ba56f08 100644 --- a/src/test/java/org/commcare/formplayer/utils/WithHqUser.java +++ b/src/test/java/org/commcare/formplayer/utils/WithHqUser.java @@ -92,4 +92,9 @@ * List of enabled toggles for the user. Defaults to a mock list of toggle_a and toggle_b */ String[] enabledToggles() default {"toggle_a", "toggle_b"}; + + /** + * The auth token to use. Defaults to null. + */ + String authToken() default ""; } diff --git a/src/test/java/org/commcare/formplayer/utils/WithHqUserSecurityContextFactory.java b/src/test/java/org/commcare/formplayer/utils/WithHqUserSecurityContextFactory.java index 830886c1b..a2351136e 100644 --- a/src/test/java/org/commcare/formplayer/utils/WithHqUserSecurityContextFactory.java +++ b/src/test/java/org/commcare/formplayer/utils/WithHqUserSecurityContextFactory.java @@ -7,6 +7,8 @@ import org.springframework.security.test.context.support.WithSecurityContextFactory; import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken; +import java.io.Closeable; + /** * A WithSecurityContextFactory that works with {@link WithHqUser}. * @@ -28,8 +30,9 @@ public static SecurityContext createSecurityContext(HqUserDetails details) { return context; } - public static void setSecurityContext(HqUserDetails details) { + public static AutoCloseable setSecurityContext(HqUserDetails details) { SecurityContext context = createSecurityContext(details); SecurityContextHolder.setContext(context); + return SecurityContextHolder::clearContext; } } diff --git a/src/test/java/org/commcare/formplayer/web/client/RestTemplateAuthTest.java b/src/test/java/org/commcare/formplayer/web/client/RestTemplateAuthTest.java new file mode 100644 index 000000000..166f055ce --- /dev/null +++ b/src/test/java/org/commcare/formplayer/web/client/RestTemplateAuthTest.java @@ -0,0 +1,213 @@ +package org.commcare.formplayer.web.client; + +import org.commcare.formplayer.util.Constants; +import org.commcare.formplayer.util.RequestUtils; +import org.commcare.formplayer.utils.HqUserDetails; +import org.commcare.formplayer.utils.MockRestTemplateBuilder; +import org.commcare.formplayer.utils.WithHqUserSecurityContextFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.data.redis.core.ValueOperations; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.test.web.client.*; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; + +import javax.servlet.http.HttpServletRequest; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.*; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withStatus; + +@ExtendWith(MockitoExtension.class) +class RestTemplateAuthTest { + + public static final String AUTH_TOKEN = "123abc"; + private RestTemplate restTemplate; + + private MockRestServiceServer mockServer; + + @Mock + private ServletRequestAttributes requestAttributes; + + @Mock + private HttpServletRequest request; + + private final String commcareHost = "https://www.commcarehq.org"; + private final String formplayerAuthKey = "authKey"; + private AutoCloseable securityContext; + + + @BeforeEach + public void init() throws URISyntaxException { + restTemplate = new MockRestTemplateBuilder() + .withCommcareHost(commcareHost) + .withFormpayerAuthKey(formplayerAuthKey) + .getRestTemplate(); + mockServer = MockRestServiceServer.createServer(restTemplate); + RequestContextHolder.setRequestAttributes(requestAttributes); + securityContext = WithHqUserSecurityContextFactory.setSecurityContext( + HqUserDetails.builder().username("testUser").authToken(AUTH_TOKEN).build() + ); + } + + @AfterEach + public void tearDown() throws Exception { + securityContext.close(); + RequestContextHolder.resetRequestAttributes(); + } + + private void mockGetRequest() { + lenient().when(requestAttributes.getRequest()).thenReturn(request); + } + + private void mockHmacRequest() { + mockGetRequest(); + lenient().when(request.getAttribute(eq(Constants.HMAC_REQUEST_ATTRIBUTE))).thenReturn(true); + } + + @Test + public void testRestTemplateSessionAuth() throws URISyntaxException { + mockGetRequest(); + + String url = commcareHost + "/a/demo/receiver/1234"; + + expectRequest(url, HttpMethod.GET) + .andExpect(sessionAuth()) + .andRespond(response()); + + restTemplate.getForObject(url, String.class); + mockServer.verify(); + } + + @Test + public void testRestTemplateSessionAuth_notForCommCare() throws URISyntaxException { + String url = "https://www.otherhost.com/a/demo/receiver/1234"; + expectRequest(url, HttpMethod.GET) + .andExpect(noAuth()) + .andRespond(response()); + + restTemplate.getForObject(url, String.class); + mockServer.verify(); + } + + @Test + public void testRestTemplateHmacAuth_GET() throws Exception { + mockHmacRequest(); + + String url = "/a/demo/receiver/1234?a=1&b=2&as=testUser"; + String authHeader = RequestUtils.getHmac(formplayerAuthKey, url.getBytes(StandardCharsets.UTF_8)); + + String fullUrl = commcareHost + url; + expectRequest(fullUrl, HttpMethod.GET) + .andExpect(hmacAuth(authHeader)) + .andRespond(response()); + + restTemplate.getForObject(fullUrl, String.class); + mockServer.verify(); + } + + @Test + public void testRestTemplateHmacAuth_POST() throws Exception { + mockHmacRequest(); + + String body = "This is the POST body"; + String authHeader = RequestUtils.getHmac(formplayerAuthKey, body.getBytes(StandardCharsets.UTF_8)); + + String url = commcareHost + "/a/demo/receiver/1234?a=1&b=2&as=testUser"; + expectRequest(url, HttpMethod.POST) + .andExpect(hmacAuth(authHeader)) + .andRespond(response()); + + restTemplate.postForObject(url, body, String.class); + mockServer.verify(); + } + + @Test + public void testRestTemplateHmacAuth_addsAsParam() throws Exception { + mockHmacRequest(); + + String url = "/a/demo/receiver/1234?a=1&b=2"; + String expectedUrl = url + "&as=testUser"; + String authHeader = RequestUtils.getHmac(formplayerAuthKey, expectedUrl.getBytes(StandardCharsets.UTF_8)); + + expectRequest(commcareHost + expectedUrl, HttpMethod.GET) + .andExpect(hmacAuth(authHeader)) + .andRespond(response()); + + restTemplate.getForObject(commcareHost + url, String.class); + mockServer.verify(); + } + + @Test + public void testRestTemplateHmacAuth_notCommCare() throws Exception { + mockHmacRequest(); + + String url = "http://localhost/a/demo/receiver/1234?a=1&b=2"; + expectRequest(url, HttpMethod.GET) + .andExpect(noAuth()) + .andRespond(response()); + + restTemplate.getForObject(url, String.class); + mockServer.verify(); + } + + private ResponseActions expectRequest(String url, HttpMethod method) throws URISyntaxException { + return mockServer.expect(ExpectedCount.once(), requestTo(new URI(url))) + .andExpect(method(method)); + } + + private RequestMatcher hmacAuth(String authHeader) { + return compoundMatcher( + header(Constants.HMAC_HEADER, authHeader), + headerDoesNotExist(Constants.POSTGRES_DJANGO_SESSION_ID), + headerDoesNotExist("Authorization"), + headerDoesNotExist("Cookie") + ); + } + + private RequestMatcher sessionAuth() { + String authHeader = Constants.POSTGRES_DJANGO_SESSION_ID + "=" + AUTH_TOKEN; + return compoundMatcher( + header(Constants.POSTGRES_DJANGO_SESSION_ID, AUTH_TOKEN), + header("Cookie", authHeader), + header("Authorization", authHeader), + headerDoesNotExist(Constants.HMAC_HEADER) + ); + } + + private RequestMatcher noAuth() { + return compoundMatcher( + headerDoesNotExist(Constants.HMAC_HEADER), + headerDoesNotExist(Constants.POSTGRES_DJANGO_SESSION_ID), + headerDoesNotExist("Authorization"), + headerDoesNotExist("Cookie") + ); + } + + private ResponseCreator response() { + return withStatus(HttpStatus.OK) + .contentType(MediaType.TEXT_HTML) + .body("response"); + } + + private RequestMatcher compoundMatcher(RequestMatcher... matchers) { + return request -> { + for (RequestMatcher matcher : matchers) { + matcher.match(request); + } + }; + } + +} diff --git a/src/test/java/org/commcare/formplayer/web/client/RestTemplateConfigTest.java b/src/test/java/org/commcare/formplayer/web/client/RestTemplateReplaceHostTest.java similarity index 72% rename from src/test/java/org/commcare/formplayer/web/client/RestTemplateConfigTest.java rename to src/test/java/org/commcare/formplayer/web/client/RestTemplateReplaceHostTest.java index 25bb807ca..2b2f0ff51 100644 --- a/src/test/java/org/commcare/formplayer/web/client/RestTemplateConfigTest.java +++ b/src/test/java/org/commcare/formplayer/web/client/RestTemplateReplaceHostTest.java @@ -1,12 +1,12 @@ package org.commcare.formplayer.web.client; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; -import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; -import static org.springframework.test.web.client.response.MockRestResponseCreators.withStatus; - +import org.commcare.formplayer.utils.MockRestTemplateBuilder; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.boot.web.client.RestTemplateBuilder; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.data.redis.core.ValueOperations; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; @@ -17,6 +17,11 @@ import java.net.URI; import java.net.URISyntaxException; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.method; +import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo; +import static org.springframework.test.web.client.response.MockRestResponseCreators.withStatus; + +@ExtendWith(MockitoExtension.class) class RestTemplateConfigTest_noCustomization { private RestTemplate restTemplate; @@ -25,13 +30,13 @@ class RestTemplateConfigTest_noCustomization { @BeforeEach public void init() throws URISyntaxException { - restTemplate = getRestTemplate("https://web", ""); + restTemplate = getRestTemplate("https://web"); mockServer = MockRestServiceServer.createServer(restTemplate); } - protected RestTemplate getRestTemplate(String commcareHost, String mode) + protected RestTemplate getRestTemplate(String commcareHost) throws URISyntaxException { - return new RestTemplateConfig(commcareHost, mode).restTemplate(new RestTemplateBuilder()); + return new MockRestTemplateBuilder().withCommcareHost(commcareHost).getRestTemplate(); } protected String getExpectedUrl() { @@ -56,9 +61,12 @@ public void testRestTemplate() throws URISyntaxException { class RestTemplateConfigTest_replaceHost extends RestTemplateConfigTest_noCustomization { @Override - public RestTemplate getRestTemplate(String commcareHost, String mode) + public RestTemplate getRestTemplate(String commcareHost) throws URISyntaxException { - return super.getRestTemplate(commcareHost, RestTemplateConfig.MODE_REPLACE_HOST); + return new MockRestTemplateBuilder() + .withCommcareHost(commcareHost) + .withExternalRequestMode(RestTemplateConfig.MODE_REPLACE_HOST) + .getRestTemplate(); } @Override diff --git a/src/test/java/org/commcare/formplayer/web/client/WebClientTest.java b/src/test/java/org/commcare/formplayer/web/client/WebClientTest.java index 5c960fa40..394c21d40 100644 --- a/src/test/java/org/commcare/formplayer/web/client/WebClientTest.java +++ b/src/test/java/org/commcare/formplayer/web/client/WebClientTest.java @@ -10,11 +10,13 @@ import com.google.common.collect.ImmutableListMultimap; import org.commcare.formplayer.services.RestoreFactory; +import org.commcare.formplayer.utils.MockRestTemplateBuilder; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -28,10 +30,9 @@ import java.net.URISyntaxException; +@ExtendWith(MockitoExtension.class) public class WebClientTest { - private RestTemplate restTemplate; - private MockRestServiceServer mockServer; private WebClient webClient; @@ -41,10 +42,7 @@ public class WebClientTest { @BeforeEach public void init() throws URISyntaxException { - MockitoAnnotations.openMocks(this); - - RestTemplateConfig config = new RestTemplateConfig("", ""); - restTemplate = config.restTemplate(new RestTemplateBuilder()); + RestTemplate restTemplate = new MockRestTemplateBuilder().getRestTemplate(); mockServer = MockRestServiceServer.createServer(restTemplate); webClient = new WebClient();