diff --git a/README.md b/README.md index c539a2f..979f949 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,15 @@ For OSDU on AWS, this client is useful in the case where you may want to perform - add_group_member - delete_group_member - create_group +- [legal](osdu/services/legal.py) + - get_legaltag + - create_legaltag + - delete_legaltag + - get_legaltags + - update_legaltag + - batch_retrive_legaltags + - validate_legaltags + - get_legaltag_properties ## Installation @@ -190,6 +199,21 @@ osdu_client = AwsOsduClient(data_partition, profile=profile) ``` +### Automatically re-authorizing the client +Each client will automatically attempt to re-authorize when its access token expires. In order for this re-authorization to succeed, you will need to supply the client with additional parameters (either through environment variables or in their consructor): + +#### Simple Client: +1. OSDU_CLIENTWITHSECRET_ID +1. OSDU_CLIENTWITHSECRET_SECRET +1. REFRESH_TOKEN +1. REFRESH_URL + +#### AWS Client: +1. OSDU_PASSWORD (in the environment variables, or somewhere else it can persist securely) + +#### Service Principal: +N/A--this client can re-authorize with just the variables needed for it to instantiate + ### Using the client Below are just a few usage examples. See [integration tests](https://github.com/pariveda/osdupy/blob/master/tests/tests_integration.py) for more comprehensive usage examples. diff --git a/osdu/client/_base.py b/osdu/client/_base.py index fd613c8..772eb3b 100644 --- a/osdu/client/_base.py +++ b/osdu/client/_base.py @@ -2,17 +2,19 @@ """ import os - +from time import time from ..services.search import SearchService from ..services.storage import StorageService from ..services.dataset import DatasetService from ..services.entitlements import EntitlementsService +from ..services.legal import LegalService class BaseOsduClient: @property def access_token(self): + self._ensure_valid_token() return self._access_token @property @@ -38,6 +40,10 @@ def delivery(self): @property def dataset(self): return self._dataset + + @property + def legal(self): + return self.__legal @property def data_partition_id(self): @@ -64,6 +70,31 @@ def __init__(self, data_partition_id, api_url: str = None): self._storage = StorageService(self) self._dataset = DatasetService(self) self._entitlements = EntitlementsService(self) - # TODO: Implement these services. - # self.__legal = LegaService(self) + self.__legal = LegalService(self) + + def _need_update_token(self): + return hasattr(self, "_token_expiration") and self._token_expiration < time() or self._access_token is None + + def _ensure_valid_token(self): + """Determines if the current access token associated with the client has expired. + If the token is not expired, the current access_token will be returned, unchanged. + If the token has expired, this function will attempt to refresh it, update it on client, and return it. + For simple clients, refresh requires a OSDU_CLIENTWITHSECRET_ID, OSDU_CLIENTWITHSECRET_SECRET, REFRESH_TOKEN, and REFRESH_URL + For Service Principal clients, refresh requires a resource_prefix and AWS_PROFILE (same as initial auth) + For AWS clients, refresh requires OSDU_USER, OSDU_PASSWORD, AWS_PROFILE, and OSDU_CLIENT_ID + + :param client: client in use + + :returns: tuple containing 2 items: the new access token and it's expiration time + - access_token: used to access OSDU services + - expires_in: expiration time for the token + """ + if(self._need_update_token()): + token = self._update_token() + else: + token = self._access_token, self._token_expiration if hasattr(self, "_token_expiration") else None + return token + + def _update_token(self): + pass #each client has their own update_token method diff --git a/osdu/client/_service_principal_util.py b/osdu/client/_service_principal_util.py index 778ee85..0e6b64d 100644 --- a/osdu/client/_service_principal_util.py +++ b/osdu/client/_service_principal_util.py @@ -26,8 +26,11 @@ # - Refactored _get_secret method to fix UnboundLocalError for local variable 'secret'. # - Refactored _get_secret method to simplify try/except flow and to print secret_name on exception. # - Updated formatting to be PEP8-compliant. -# +# 2022-03-16 johnny.reichman@parivedasolutions.com +# - Updated to return the token expiration in addition to the token +# - Added a more descriptive exception check after the POST request import base64 +from time import time import boto3 import requests import json @@ -113,7 +116,8 @@ def get_service_principal_token(self, resource_prefix): token_url = '{}?grant_type=client_credentials&client_id={}&scope={}'.format( token_url, client_id, aws_oauth_custom_scope) - + response = requests.post(url=token_url, headers=headers) - - return json.loads(response.content.decode())['access_token'] + response.raise_for_status() + response_json = json.loads(response.content.decode()) + return response_json['access_token'], response_json['expires_in'] + time() diff --git a/osdu/client/aws.py b/osdu/client/aws.py index 266b7d1..321258f 100644 --- a/osdu/client/aws.py +++ b/osdu/client/aws.py @@ -1,4 +1,5 @@ import os +from time import time import boto3 from ._base import BaseOsduClient @@ -71,6 +72,17 @@ def get_tokens(self, password, secret_hash) -> None: AuthParameters=auth_params ) - + self._token_expiration = response['AuthenticationResult']['ExpiresIn'] + time() self._access_token = response['AuthenticationResult']['AccessToken'] self._refresh_token = response['AuthenticationResult']['RefreshToken'] + + # TODO: refresh can only be used if password is in environment variables. Is there another way to store the password securely? + def _update_token(self): + password = os.environ.get('OSDU_PASSWORD') + if(password): + self.get_tokens(password, self._secret_hash) + password = None + return self._access_token, self._token_expiration + else: + raise Exception('Expired or invalid access token. OSDU_PASSWORD env variable must be set for token to be auto refreshed.') + diff --git a/osdu/client/aws_service_principal.py b/osdu/client/aws_service_principal.py index 4d12f67..c16630c 100644 --- a/osdu/client/aws_service_principal.py +++ b/osdu/client/aws_service_principal.py @@ -4,12 +4,21 @@ class AwsServicePrincipalOsduClient(BaseOsduClient): + @property + def resource_prefix(self): + return self._resource_prefix + def __init__(self, data_partition_id: str, resource_prefix: str, profile: str = None, region: str = None): self._sp_util = ServicePrincipalUtil( resource_prefix, profile=profile, region=region) self._resource_prefix = resource_prefix - self._access_token = self._get_tokens() + self._access_token,self._token_expiration = self._get_tokens() + super().__init__(data_partition_id, self._sp_util.api_url) def _get_tokens(self): return self._sp_util.get_service_principal_token(self._resource_prefix) + + def _update_token(self): + self._access_token, self._token_expiration = self._sp_util.get_service_principal_token(self._resource_prefix) + return self._access_token, self._token_expiration diff --git a/osdu/client/simple.py b/osdu/client/simple.py index 70cfe92..0426413 100644 --- a/osdu/client/simple.py +++ b/osdu/client/simple.py @@ -1,3 +1,6 @@ +import os +import requests +from time import time from ._base import BaseOsduClient @@ -6,15 +9,39 @@ class SimpleOsduClient(BaseOsduClient): This client assumes you are obtaining a token yourself (e.g. via your application's login form or otheer mechanism. With this SimpleOsduClient, you simply provide that token. - With this simplicity, you are also then respnsible for reefreeshing the token as needed and - re-instantiating the client with the new token. + With this simplicity, you are also then respnsible for refreshing the token as needed either by manually + re-instantiating the client with the new token or by providing the authentication client id, secret, refresh token, and refresh url + and allowing the client to attempt the refresh automatically. """ - def __init__(self, data_partition_id: str, access_token: str, api_url: str=None) -> None: + def __init__(self, data_partition_id: str, access_token: str=None, api_url: str=None, refresh_token: str=None, refresh_url: str=None) -> None: """ :param: access_token: The access token only (not including the 'Bearer ' prefix). :param: api_url: must be only the base URL, e.g. https://myapi.myregion.mydomain.com + :param: refresh_token: The refresh token only (not including the 'Bearer ' prefix). + :param: refresh_url: The authentication Url, typically a Cognito URL ending in "/token". """ super().__init__(data_partition_id, api_url) - self._access_token = access_token \ No newline at end of file + self._access_token = access_token + self._refresh_token = refresh_token or os.environ.get('OSDU_REFRESH_TOKEN') + self._refresh_url = refresh_url or os.environ.get('OSDU_REFRESH_URL') + self._client_id = os.environ.get('OSDU_CLIENTWITHSECRET_ID') + self._client_secret = os.environ.get('OSDU_CLIENTWITHSECRET_SECRET') + + def _update_token(self) -> dict: + if not self._refresh_token or not self._refresh_url: + raise Exception('Expired or invalid access token. Both \'refresh_token\' and \'refresh_url\' must be set for token to be auto refreshed.') + + data = {'grant_type': 'refresh_token', + 'client_id': self._client_id, + 'client_secret': self._client_secret, + 'refresh_token': self._refresh_token, + 'scope': 'openid email'} + headers = {} + headers["Content-Type"] = "application/x-www-form-urlencoded" + response = requests.post(url=self._refresh_url,headers=headers, data=data) + response.raise_for_status() + self._access_token = response.json()["access_token"] + self._token_expiration = response.json()["expires_in"] + time() + return self._access_token, self._token_expiration \ No newline at end of file diff --git a/osdu/services/base.py b/osdu/services/base.py index bfd626a..d0077cd 100644 --- a/osdu/services/base.py +++ b/osdu/services/base.py @@ -4,9 +4,10 @@ def __init__(self, client, service_name: str, service_version: int): self._client = client self._service_url = f'{self._client.api_url}/api/{service_name}/v{service_version}' + def _headers(self): return { "Content-Type": "application/json", "data-partition-id": self._client._data_partition_id, "Authorization": "Bearer " + self._client.access_token - } + } \ No newline at end of file diff --git a/osdu/services/legal.py b/osdu/services/legal.py new file mode 100644 index 0000000..28a0dc0 --- /dev/null +++ b/osdu/services/legal.py @@ -0,0 +1,98 @@ +""" Provides a simple Python interface to the OSDU Legal API. +""" +from typing import List +import requests +from .base import BaseService + + +class LegalService(BaseService): + + def __init__(self, client): + super().__init__(client, service_name='legal', service_version=1) + + def get_legaltag(self, legaltag_name: str): + """Returns information about the given legaltag. + + param legaltag_name: the name of the legaltag of interest + """ + url = f'{self._service_url}/legaltags/{legaltag_name}' + response = self.__execute_request('get', url) + + return response.json() + + def create_legaltag(self, legaltag: dict): + """Create a new legaltag. + + param legaltag: a JSON representation of a legaltag + """ + url = f'{self._service_url}/legaltags' + response = self.__execute_request('post', url, json=legaltag) + + return response.json() + + def delete_legaltag(self, legaltag_name: str) -> bool: + """Deletes the given legaltag. This operation cannot be reverted (except by re-creating the legaltag). + + :param legaltag_name: the name of the legaltag to delete + :returns: True if legaltag deleted successfully. Otherwise False. + """ + url = f'{self._service_url}/legaltags/{legaltag_name}' + response = self.__execute_request('delete', url) + + return response.status_code == 204 + + def get_legaltags(self, valid: bool = True): + """Fetches all matching legaltags. + + :param valid: Boolean to restrict results to only valid legaltags (true) or only invalid legal tags (false). Default is true + """ + url = f'{self._service_url}/legaltags/' + ('?valid=true' if valid else '?valid=false') + response = self.__execute_request('get', url) + + return response.json() + + def update_legaltag(self, legaltag: dict): + """Updates a legaltag. Empty properties are ignored, not deleted. + + :param legaltag: dictionary of properties to add/change to an existing legaltag + """ + url = f'{self._service_url}/legaltags' + response = self.__execute_request('put', url, json=legaltag) + + return response.json() + + def batch_retrive_legaltags(self, legaltag_names: List[str]): + """Retrieves information about a list of legaltags + + :param legaltag_names: List of legaltag names to fetch information about + """ + url = f'{self._service_url}/legaltags:batchRetrieve' + payload = {'names': legaltag_names} + response = self.__execute_request('post', url, json=payload) + + return response.json() + + def validate_legaltags(self, legaltag_names: List[str]): + """Validates the given legaltags--returning a list of which legaltags are invalid. + + :param legaltag_names: List of legaltag names to validate + """ + url = f'{self._service_url}/legaltags:validate' + payload = {'names': legaltag_names} + response = self.__execute_request('post', url, json=payload) + + return response.json() + + def get_legaltag_properties(self): + """Fetch information about possible values for legaltag properties""" + url = f'{self._service_url}/legaltags:properties' + response = self.__execute_request('get', url) + + return response.json() + + def __execute_request(self, method: str, url: str, json=None): + headers = self._headers() + response = requests.request(method, url, headers=headers, json=json) + response.raise_for_status() + + return response diff --git a/tests/integration.py b/tests/integration.py index 9d7c221..a2dec4a 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -1,7 +1,14 @@ -""" In order to run these tests, you must provide an appropriate `user` and `password`. The password -can be set locally by setting the environment variable OSDU_PASSWORD. If using +""" In order to run these tests, you must provide appropriate environment variables. If using VS Code, then you can set this in your local `.env` file in your workspace directory to easily switch between OSDU environments. +Most Integration tests require: + OSDU_USER + AWS_PROFILE +SimpleClient update_token integration test require: + OSDU_CLIENTWITHSECRET_ID + OSDU_CLIENTWITHSECRET_SECRET + OSDU_REFRESH_URL + OSDU_REFRESH_TOKEN """ import json import os @@ -33,6 +40,18 @@ def test_endpoint_access(self): result = client.search.query(query)['results'] self.assertEqual(1, len(result)) + + def test_update_token(self): + query = { + "kind": f"*:*:*:*", + "limit": 1 + } + client = SimpleOsduClient(data_partition) + old_access_token = client._access_token + client._token_expiration = 0 #change token expiration so we force an update + updated_access_token = client.access_token + self.assertIsNotNone(updated_access_token) + self.assertNotEqual(old_access_token,updated_access_token) class TestAwsOsduClient(TestCase): @@ -41,6 +60,14 @@ def test_get_access_token(self): client = AwsOsduClient(data_partition) self.assertIsNotNone(client.access_token) + def test_update_token(self): + client = AwsOsduClient(data_partition) + old_access_token = client.access_token + client._token_expiration = 0 # change the token expiration so we force a refresh + updated_access_token = client.access_token + self.assertIsNotNone(updated_access_token) + self.assertNotEqual(old_access_token,updated_access_token) + class TestAwsServicePrincipalOsduClient(TestCase): @@ -70,6 +97,18 @@ def test_endpoint_access(self): self.assertEqual(1, len(result)) + def test_update_token(self): + client = AwsServicePrincipalOsduClient(data_partition, + os.environ['OSDU_RESOURCE_PREFIX'], + profile=os.environ['AWS_PROFILE'], + region=os.environ['AWS_DEFAULT_REGION'] + ) + old_access_token = client.access_token + client._token_expiration = 0 # change the token expiration so we force a refresh + updated_access_token = client.access_token + self.assertIsNotNone(updated_access_token) + self.assertNotEqual(old_access_token,updated_access_token) + class TestOsduServiceBase(TestCase): @@ -285,3 +324,62 @@ def tearDownClass(cls): for record_id in cls.test_records: cls.osdu.storage.purge_record(record_id) super().tearDownClass() + +class TestLegalService(TestOsduServiceBase): + + def test_get_legaltags(self): + result = self.osdu.legal.get_legaltags() + + self.assertTrue(len(result['legalTags']) > 0) + + def test_validate_legaltags(self): + legaltag_names = ["osdu-public-usa-dataset", "osdu-testing-legal-tag-plz-delete"] + result = self.osdu.legal.validate_legaltags(legaltag_names) + + self.assertIsNotNone(result['invalidLegalTags']) + + def test_get_legaltag_properties(self): + result = self.osdu.legal.get_legaltag_properties() + + self.assertIsNotNone(result['dataTypes']) + +class TestLegalService_WithSideEffects(TestOsduServiceBase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + create_legaltag_data_file = 'tests/test_data/test_create_legaltag.json' + with open(create_legaltag_data_file, 'r') as _file: + cls.legaltag_to_create = json.load(_file) + + update_legaltag_data_file = 'tests/test_data/test_update_legaltag.json' + with open(update_legaltag_data_file, 'r') as _file: + cls.legaltag_to_update = json.load(_file) + + def test_001_create_legaltag(self): + result = self.osdu.legal.create_legaltag(self.legaltag_to_create) + + self.assertIsNotNone(result["name"]) + + def test_002_get_legaltag(self): + legaltag = self.osdu.legal.get_legaltag(self.legaltag_to_create['name']) + + self.assertIsNotNone(legaltag["name"]) + + def test_003_batch_retrieve_legaltag(self): + legaltag_names = ["osdu-public-usa-dataset", self.legaltag_to_create['name']] + result = self.osdu.legal.batch_retrive_legaltags(legaltag_names) + + self.assertTrue(len(result['legalTags']) > 0) + + def test_004_update_legaltag(self): + result = self.osdu.legal.update_legaltag(self.legaltag_to_update) + + self.assertEqual(result['description'], self.legaltag_to_update['description']) + + def test_005_delete_legaltag(self): + tag_was_deleted = self.osdu.legal.delete_legaltag(self.legaltag_to_create['name']) + + self.assertTrue(tag_was_deleted) + + diff --git a/tests/test_data/test_create_legaltag.json b/tests/test_data/test_create_legaltag.json new file mode 100644 index 0000000..5a69d76 --- /dev/null +++ b/tests/test_data/test_create_legaltag.json @@ -0,0 +1,16 @@ +{ + "name": "osdu-testing-legal-tag-plz-delete", + "description": "Another default legal tag", + "properties": { + "countryOfOrigin": [ + "US" + ], + "contractId": "A1234", + "expirationDate": "2040-06-02", + "originator": "Default", + "dataType": "Public Domain Data", + "securityClassification": "Public", + "personalData": "No Personal Data", + "exportClassification": "EAR99" + } +} \ No newline at end of file diff --git a/tests/test_data/test_create_single_record.json b/tests/test_data/test_create_single_record.json index f160eb1..a41acbe 100644 --- a/tests/test_data/test_create_single_record.json +++ b/tests/test_data/test_create_single_record.json @@ -6,7 +6,7 @@ "Name": "Test Record 1" }, "legal": { - "legaltags": ["osdu-public-usa-dataset-1"], + "legaltags": ["osdu-public-usa-dataset"], "otherRelevantDataCountries": ["US"], "status": "compliant" }, @@ -23,7 +23,7 @@ "Name": "Test Record 2" }, "legal": { - "legaltags": ["osdu-public-usa-dataset-1"], + "legaltags": ["osdu-public-usa-dataset"], "otherRelevantDataCountries": ["US"], "status": "compliant" }, @@ -40,7 +40,7 @@ "Name": "Test Record 3" }, "legal": { - "legaltags": ["osdu-public-usa-dataset-1"], + "legaltags": ["osdu-public-usa-dataset"], "otherRelevantDataCountries": ["US"], "status": "compliant" }, diff --git a/tests/test_data/test_update_legaltag.json b/tests/test_data/test_update_legaltag.json new file mode 100644 index 0000000..e11cc01 --- /dev/null +++ b/tests/test_data/test_update_legaltag.json @@ -0,0 +1,6 @@ +{ + "name": "osdu-testing-legal-tag-plz-delete", + "contractId":"A1234", + "expirationDate":2222222222222, + "description": "new description" +} \ No newline at end of file diff --git a/tests/unit.py b/tests/unit.py index 1e4c77b..423261e 100644 --- a/tests/unit.py +++ b/tests/unit.py @@ -2,6 +2,7 @@ import hashlib import hmac from unittest import TestCase, mock +from time import time from osdu.client import ( AwsOsduClient, @@ -12,7 +13,7 @@ class TestAwsServicePrincipalOsduClient(TestCase): - @mock.patch('osdu.client._service_principal_util.ServicePrincipalUtil.get_service_principal_token') + @mock.patch('osdu.client._service_principal_util.ServicePrincipalUtil.get_service_principal_token', return_value=["testtoken",time()+ 999]) @mock.patch('boto3.Session') @mock.patch('base64.b64decode') def test_initialize_aws_client_with_args(self, mock_b64decode, mock_session, mock_sputil):