diff --git a/httprest/api.py b/httprest/api.py index 5990c82..302e926 100644 --- a/httprest/api.py +++ b/httprest/api.py @@ -1,5 +1,6 @@ """API.""" +import urllib.parse from typing import Optional as _Optional from typing import Union as _Union @@ -31,7 +32,7 @@ def __init__( def _request( self, method: str, - endpoint: str, + endpoint: _Optional[str] = None, data: _Optional[_Union[dict, bytes]] = None, json: _Optional[dict] = None, headers: _Optional[dict] = None, @@ -42,7 +43,7 @@ def _request( # pylint: disable=too-many-arguments """Make API request. - :param endpoint: API endpoint. Will be appended to the base URL + :param endpoint: API endpoint. Will be joined with the base URL Other parameters are the same as for the `HTTPClient.request` method """ @@ -57,8 +58,8 @@ def _request( cert=cert, ) - def _build_url(self, endpoint: str) -> str: - return f"{self._base_url}/{endpoint.strip('/')}" + def _build_url(self, endpoint: _Optional[str]) -> str: + return urllib.parse.urljoin(self._base_url, endpoint) def __str__(self) -> str: return f"{self.__class__.__name__}(base_url='{self._base_url}')" diff --git a/tests/unit/fakes.py b/tests/unit/fakes.py index 95a4057..d3668d1 100644 --- a/tests/unit/fakes.py +++ b/tests/unit/fakes.py @@ -4,16 +4,19 @@ from typing import Optional from httprest import API +from httprest.http import HTTPResponse from httprest.http.fake_client import FakeHTTPClient class _TestAPI(API): """API client for tests.""" - def make_call(self): + def make_call( + self, endpoint: Optional[str] = "/example/endpoint/" + ) -> HTTPResponse: """Make API call.""" return self._request( - "POST", "/example/endpoint/", json={"k": "v"}, headers={"h": "v"} + "POST", endpoint, json={"k": "v"}, headers={"h": "v"} ) diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 7dbfe40..c0f4146 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -1,5 +1,9 @@ """Tests for the API client.""" +import typing + +import pytest + from httprest.http import HTTPResponse from .fakes import FakeHTTPClient, build_api @@ -19,6 +23,21 @@ def test_api_call(): "headers": {"h": "v"}, "json": {"k": "v"}, "method": "POST", - "url": "http://fake.com/example/endpoint", + "url": "http://fake.com/example/endpoint/", }, ] + + +@pytest.mark.parametrize("endpoint", ["", None]) +def test_url_without_endpoint(endpoint: typing.Optional[str]): + """Test for request URL when endpoint is not specified. + + The base URL must be used. + """ + base = "http://fake.com" + comps = build_api( + base_url=base, + http_client=FakeHTTPClient(responses=[HTTPResponse(200, b"", {})]), + ) + comps.api.make_call(endpoint=endpoint) + assert comps.http_client.history[0]["url"] == base