diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index 300630b..369f5b9 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -9,10 +9,11 @@ import google.protobuf.internal.api_implementation from google.protobuf import symbol_database as _symbol_database -from snet.sdk.registry.models import StorageType +from snet.sdk.exceptions import NoGroupsFoundError, GroupNotFoundError, ServiceMetadataMismatchError +from snet.sdk.registry.models import StorageType, FileURI from snet.sdk.registry.organization_metadata import OrganizationMetadata from snet.sdk.registry.registry_contract import RegistryContract -from snet.sdk.registry.service_metadata import MPEServiceMetadata, ServiceMetadata +from snet.sdk.registry.service_metadata import ServiceMetadata, Group with warnings.catch_warnings(): # Suppress the eth-typing package`s warnings related to some new networks @@ -74,15 +75,18 @@ def create_service_client( options=None, concurrent_calls: int = 1, ): - - # Create and instance of the Config object, - # so we can create an instance of ClientLibGenerator + service_metadata = self._enhance_service_metadata(org_id, service_id) lib_generator = ClientLibGenerator(self.storage_provider, org_id, service_id) - # Download the proto file and generate stubs if needed + if service_metadata.service_api_source is not None: + service_api_source = service_metadata.service_api_source + else: + service_api_source = service_metadata.model_ipfs_hash + service_api_source = FileURI.from_raw_uri(service_api_source) + force_update = config.FORCE_UPDATE if force_update: - lib_generator.generate_client_library() + lib_generator.generate_client_library(service_api_source) else: path_to_pb_files = lib_generator.proto_dir pb_2_file_name = find_file_by_keyword( @@ -93,7 +97,7 @@ def create_service_client( ) if not pb_2_file_name or not pb_2_grpc_file_name: print("Generating client library...") - lib_generator.generate_client_library() + lib_generator.generate_client_library(service_api_source) if options is None: options = dict() @@ -103,8 +107,7 @@ def create_service_client( if payment_strategy is None: payment_strategy = payment_strategy_type.value() - service_metadata = self._enhance_service_metadata(org_id, service_id) - group = self._get_service_group_details(service_metadata, group_name) + group = self._get_service_group(org_id, service_id, service_metadata, group_name) service_stubs = self.get_service_stub(lib_generator) @@ -112,7 +115,6 @@ def create_service_client( _service_client = ServiceClient( org_id, service_id, - service_metadata, group, service_stubs, payment_strategy, @@ -162,7 +164,7 @@ def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) module_name = os.path.splitext(file_name)[0] return ModuleName(module_name) - def get_service_metadata(self, org_id, service_id): + def get_service_metadata(self, org_id, service_id) -> ServiceMetadata: service = self.registry_contract.get_service(org_id, service_id) return self.storage_provider.fetch_service_metadata(service.metadata_uri) @@ -170,28 +172,20 @@ def get_organization_metadata(self, org_id: str) -> OrganizationMetadata: org = self.registry_contract.get_org(org_id) return self.storage_provider.fetch_org_metadata(org.metadata_uri) - def _get_first_group(self, service_metadata: MPEServiceMetadata) -> dict: - return service_metadata["groups"][0] - - def _get_group_by_group_name( - self, service_metadata: MPEServiceMetadata, group_name: str - ) -> dict: - for group in service_metadata["groups"]: - if group["group_name"] == group_name: - return group - # TODO: configure exceptions - raise Exception() - - def _get_service_group_details( - self, service_metadata: MPEServiceMetadata, group_name: str - ) -> dict: - if len(service_metadata["groups"]) == 0: - raise Exception("No Groups found for given service, Please add group to the service") + def _get_service_group( + self, org_id: str, service_id: str, service_metadata: ServiceMetadata, group_name: str + ) -> Group: + if len(service_metadata.groups) == 0: + raise NoGroupsFoundError(org_id, service_id) if group_name is None: - return self._get_first_group(service_metadata) + return service_metadata.groups[0] - return self._get_group_by_group_name(service_metadata, group_name) + for group in service_metadata.groups: + if group.group_name == group_name: + return group + + raise GroupNotFoundError(org_id, service_id, group_name) def get_organization_list(self) -> list: return self.registry_contract.list_orgs() @@ -215,13 +209,68 @@ def publish_service_comprehensively( 5. publish service into Registry contract """ proto_uri = self.storage_provider.publish_proto(proto_dir, storage_type) - metadata.service_api_source = str(proto_uri) metadata.mpe_address = self.mpe_contract.contract.address + self._check_and_update_service_groups(org_id, metadata.groups) metadata_uri = self.storage_provider.publish_service_metadata(metadata, storage_type) + receipt = self.registry_contract.create_service( self.account, org_id, service_id, metadata_uri ) return receipt["status"] != 0 + + def update_service( + self, + org_id: str, + service_id: str, + metadata: ServiceMetadata, + proto_dir: Union[str, Path, None] = None, + storage_type: StorageType = StorageType.IPFS, + ) -> bool: + if proto_dir is not None: + proto_uri = self.storage_provider.publish_proto(proto_dir, storage_type) + metadata.service_api_source = str(proto_uri) + + if not metadata.mpe_address: + metadata.mpe_address = self.mpe_contract.contract.address + + self._check_and_update_service_groups(org_id, metadata.groups) + metadata_uri = self.storage_provider.publish_service_metadata(metadata, storage_type) + + receipt = self.registry_contract.update_service_metadata( + self.account, org_id, service_id, metadata_uri + ) + + return receipt["status"] != 0 + + def update_organization( + self, + org_id: str, + organization_metadata: OrganizationMetadata, + storage_type: StorageType = StorageType.IPFS, + ) -> bool: + metadata_uri = self.storage_provider.publish_organization_metadata( + organization_metadata, storage_type + ) + receipt = self.registry_contract.update_org_metadata(self.account, org_id, metadata_uri) + + return receipt["status"] != 0 + + def _check_and_update_service_groups( + self, org_id: str, service_groups: list[Group] + ) -> list[Group]: + org = self.registry_contract.get_org(org_id) + org_metadata = self.storage_provider.fetch_org_metadata(org.metadata_uri) + org_groups_map = {g.group_name: g for g in org_metadata.groups} + + for group in service_groups: + try: + group.group_id = org_groups_map[group.group_name].group_id + except KeyError: + raise ServiceMetadataMismatchError( + "All groups added to the service must also exist in the organization!" + ) + + return service_groups diff --git a/snet/sdk/client_lib_generator.py b/snet/sdk/client_lib_generator.py index bd58ec4..5b28f76 100644 --- a/snet/sdk/client_lib_generator.py +++ b/snet/sdk/client_lib_generator.py @@ -1,6 +1,7 @@ import os from pathlib import Path +from snet.sdk.registry.models import FileURI from snet.sdk.registry.storage_provider import StorageProvider from snet.sdk.utils.utils import compile_proto @@ -20,9 +21,9 @@ def __init__( self.proto_dir: Path = proto_dir if proto_dir else Path.home().joinpath(".snet") self.generate_directories_by_params() - def generate_client_library(self) -> None: + def generate_client_library(self, service_api_source: FileURI) -> None: try: - self.receive_proto_files() + self.receive_proto_files(service_api_source) compilation_result = compile_proto( entry_path=self.proto_dir, codegen_dir=self.proto_dir, @@ -35,8 +36,8 @@ def generate_client_library(self) -> None: f'in org with id "{self.org_id}" ' f"generated at {self.proto_dir}" ) - except Exception as e: - print(str(e)) + except Exception: + raise Exception("Error while proto compilation!") def generate_directories_by_params(self) -> None: if not self.proto_dir.is_absolute(): @@ -47,13 +48,7 @@ def create_service_client_libraries_path(self) -> None: self.proto_dir = self.proto_dir.joinpath(self.org_id, self.service_id, self.language) self.proto_dir.mkdir(parents=True, exist_ok=True) - def receive_proto_files(self) -> None: - metadata = self._metadata_provider.fetch_service_metadata( - org_id=self.org_id, service_id=self.service_id - ) - service_api_source = metadata.get("service_api_source") or metadata.get("model_ipfs_hash") - - # Receive proto files + def receive_proto_files(self, service_api_source: FileURI) -> None: if self.proto_dir.exists(): self._metadata_provider.fetch_and_extract_proto(service_api_source, self.proto_dir) else: diff --git a/snet/sdk/exceptions.py b/snet/sdk/exceptions.py index 2e55c91..e83d213 100644 --- a/snet/sdk/exceptions.py +++ b/snet/sdk/exceptions.py @@ -1,7 +1,15 @@ +from grpc import RpcError + + class SnetSDKError(Exception): + """Base SNET SDK exception class""" + pass +# ==================== Blockchain interaction Errors ==================== + + class ContractError(SnetSDKError): pass @@ -27,7 +35,7 @@ def __init__(self, tx_hash: str, event_name: str): ) -class RegistryContractError(ContractError): +class RegistryContractError(ValueError, ContractError): pass @@ -41,18 +49,28 @@ def __init__(self, org_id: str, service_id: str): super().__init__(f"Service with org_id={org_id} service_id={service_id} doesn't exist!") -class UnauthorizedCallerError(TransactionError): +class UnauthorizedCallerError(RegistryContractError): pass class UnauthorizedOrgMemberError(UnauthorizedCallerError): def __init__(self, org_id: str, address: str): - super().__init__(f"Address {address} isn't owner or member of the organization {org_id}!") + super().__init__( + f"Address {address} isn't an owner or a member of the organization {org_id}!" + ) -class UnsupportedStorageTypeError(ValueError, SnetSDKError): - def __init__(self, storage_type: str): - super().__init__(f"Unsupported storage type: {storage_type}!") +class UnauthorizedOrgOwnerError(UnauthorizedCallerError): + def __init__(self, org_id: str, address: str): + super().__init__(f"Address {address} isn't an owner of the organization {org_id}!") + + +class IncorrectWalletAddressError(RegistryContractError): + def __init__(self, address: str): + super().__init__(f"Address {address} is not a correct Ethereum address!") + + +# ==================== Metadata Errors ==================== class MetadataMismatchError(ValueError, SnetSDKError): @@ -67,10 +85,18 @@ class OrganizationMetadataMismatchError(MetadataMismatchError): pass +# ==================== Storage Provider Errors ==================== + + class StorageProviderError(SnetSDKError): pass +class UnsupportedStorageTypeError(ValueError, StorageProviderError): + def __init__(self, storage_type: str): + super().__init__(f"Unsupported storage type: {storage_type}!") + + class LighthouseError(StorageProviderError): def __init__(self): super().__init__( @@ -78,7 +104,21 @@ def __init__(self): ) -class PublishProtoError(StorageProviderError, ValueError): +class IPFSError(StorageProviderError): + pass + + +class IPFSHashMismatchError(IPFSError): + def __init__(self): + super().__init__("IPFS hash mismatch with data") + + +class IPFSHashCheckError(IPFSError): + def __init__(self): + super().__init__("IPFS hash integrity check failed!") + + +class PublishProtoError(ValueError, StorageProviderError): pass @@ -90,3 +130,64 @@ def __init__(self, dir_path: str): class ProtoFilesNotFoundError(PublishProtoError): def __init__(self, dir_path: str): super().__init__(f"Cannot find any .proto file in {dir_path}!") + + +class ExtractingProtoError(ValueError, StorageProviderError): + pass + + +# ==================== Training Errors ==================== + + +class TrainingError(SnetSDKError): + pass + + +class WrongDatasetError(TrainingError): + def __init__(self, errors: list[str]): + self.errors = errors + exception_msg = "Dataset check failed:\n" + for check in errors: + exception_msg += f"\t{check}\n" + super().__init__(exception_msg) + + +class WrongMethodError(ValueError, TrainingError): + def __init__(self, method_name: str): + super().__init__(f"Method with name {method_name} not found!") + + +class NoTrainingError(ValueError, TrainingError): + def __init__(self, org_id: str, service_id: str): + super().__init__( + f"Training is not implemented for the service with org_id={org_id} and service_id={service_id}!" + ) + + +class GRPCError(RpcError, TrainingError): + def __init__(self, error: RpcError): + super().__init__(f"An error occurred during the grpc call: {error}.") + + +class NoSuchModelError(ValueError, TrainingError): + def __init__(self, model_id: str): + super().__init__(f"Model with id {model_id} not found!") + + +# ==================== Service Client Errors ==================== + + +class ServiceClientError(SnetSDKError): + pass + + +class NoGroupsFoundError(ServiceClientError): + def __init__(self, org_id: str, service_id: str): + super().__init__(f"Service with org_id={org_id} service_id={service_id} has no groups!") + + +class GroupNotFoundError(ServiceClientError): + def __init__(self, org_id: str, service_id: str, group_name: str): + super().__init__( + f"Service with org_id={org_id} service_id={service_id} has no group with group_name={group_name}!" + ) diff --git a/snet/sdk/mpe/mpe_contract.py b/snet/sdk/mpe/mpe_contract.py index 0ee3143..a403c5d 100644 --- a/snet/sdk/mpe/mpe_contract.py +++ b/snet/sdk/mpe/mpe_contract.py @@ -69,6 +69,6 @@ def channel_extend_and_add_funds(self, account: Account, channel_id, expiration, ) def _fund_escrow_account(self, account: Account, amount): - current_escrow_balance = self.balance(account.address) + current_escrow_balance = self.balance(account) if amount > current_escrow_balance: - self.deposit(amount - current_escrow_balance) + self.deposit(account, amount - current_escrow_balance) diff --git a/snet/sdk/registry/organization_metadata.py b/snet/sdk/registry/organization_metadata.py index 69c0d5d..6e2d05e 100644 --- a/snet/sdk/registry/organization_metadata.py +++ b/snet/sdk/registry/organization_metadata.py @@ -1,3 +1,5 @@ +import base64 +import secrets from typing import Optional, Literal from pydantic import BaseModel, Field @@ -5,6 +7,10 @@ from snet.sdk.exceptions import OrganizationMetadataMismatchError +def generate_group_id() -> str: + return base64.b64encode(secrets.token_bytes(32)).decode() + + class Description(BaseModel): url: str = Field(default="") description: str = Field(min_length=1) @@ -24,19 +30,19 @@ class Contact(BaseModel): class PaymentChannelStorageClient(BaseModel): connection_timeout: str = Field(default="5s") request_timeout: str = Field(default="5s") - endpoints: list[str] = Field(default=[]) + endpoints: list[str] = Field(min_length=1) class Payment(BaseModel): payment_address: str payment_expiration_threshold: int = Field(default=40320) payment_channel_storage_type: Literal["etcd"] = Field(default="etcd") - payment_channel_storage_client: Optional[PaymentChannelStorageClient] = Field(default=None) + payment_channel_storage_client: PaymentChannelStorageClient class Group(BaseModel): group_name: str - group_id: str + group_id: str = Field(default_factory=generate_group_id) payment: Payment @@ -46,8 +52,63 @@ class OrganizationMetadata(BaseModel): org_type: Literal["organization", "individual"] description: Optional[Description] = Field(default=None) assets: Optional[Assets] = Field(default=None) - contacts: list[Contact] = Field(default=[]) - groups: list[Group] = Field(default=[]) + contacts: list[Contact] = Field(default_factory=list) + groups: list[Group] = Field(default_factory=list) + + def add_group( + self, + payment_address: str, + payment_channel_storage_client_endpoints: list[str], + group_name: str = "default_group", + payment_expiration_threshold: int = 40320, + payment_channel_storage_client_connection_timeout: str = "5s", + payment_channel_storage_client_request_timeout: str = "5s", + ) -> "OrganizationMetadata": + existing_org_groups = [g.group_name for g in self.groups] + if group_name in existing_org_groups: + raise OrganizationMetadataMismatchError( + f"Group with group_name {group_name} already exists!" + ) + + self.groups.append( + Group( + group_name=group_name, + payment=Payment( + payment_address=payment_address, + payment_expiration_threshold=payment_expiration_threshold, + payment_channel_storage_client=PaymentChannelStorageClient( + connection_timeout=payment_channel_storage_client_connection_timeout, + request_timeout=payment_channel_storage_client_request_timeout, + endpoints=payment_channel_storage_client_endpoints, + ), + ), + ) + ) + + return self + + def add_contact( + self, + email: str = "", + phone: str = "", + contact_type: Literal["general", "support"] = "support", + ) -> "OrganizationMetadata": + self.contacts.append(Contact(email=email, phone=phone, contact_type=contact_type)) + + return self + + def add_description( + self, description: str, short_description: str, url: str = "" + ) -> "OrganizationMetadata": + self.description = Description( + url=url, description=description, short_description=short_description + ) + + return self + + def add_assets(self, hero_image: str) -> "OrganizationMetadata": + self.assets = Assets(hero_image=hero_image) + return self def generate_final_json(self): if not self.org_id: @@ -68,3 +129,11 @@ def generate_final_json(self): ) return self.model_dump_json(indent=2, exclude_none=True) + + def validate_metadata(self) -> tuple[bool, str]: + try: + self.generate_final_json() + except OrganizationMetadataMismatchError as e: + return False, str(e) + + return True, "" diff --git a/snet/sdk/registry/registry_contract.py b/snet/sdk/registry/registry_contract.py index aeba1c1..287ddf0 100644 --- a/snet/sdk/registry/registry_contract.py +++ b/snet/sdk/registry/registry_contract.py @@ -1,5 +1,6 @@ -from typing import Union +from typing import Optional +from eth_utils import is_checksum_address from snet.contracts import get_contract_object from web3.types import TxReceipt @@ -9,6 +10,8 @@ OrganizationNotFoundError, ServiceNotFoundError, UnauthorizedOrgMemberError, + UnauthorizedOrgOwnerError, + IncorrectWalletAddressError, ) from snet.sdk.registry.models import RawOrgData, OrgData, ServiceData, RawServiceData, FileURI from snet.sdk.utils.utils import ( @@ -71,17 +74,67 @@ def list_service_for_org(self, org_id: str) -> list[str]: # WRITE METHODS - def add_org_members( - self, account: Account, org_id: str, members: Union[str, list[str], None] - ): ... + def add_org_members(self, account: Account, org_id: str, new_members: list[str]) -> TxReceipt: + org = self.get_org(org_id) + + if account.address != org.owner: + raise UnauthorizedOrgOwnerError(account.address, org_id) + + for member in new_members: + if not is_checksum_address(member): + raise IncorrectWalletAddressError(member) + + return account.send_transaction( + self.contract.functions.changeOrganizationMetadataURI, + type_converter("bytes32")(org_id), + new_members, + ) + + def update_org_metadata( + self, account: Account, org_id: str, metadata_uri: FileURI + ) -> TxReceipt: + org = self.get_org(org_id) - def update_org_metadata(self, account: Account, org_id: str, metadata_uri: str): ... + if account.address != org.owner: + raise UnauthorizedOrgOwnerError(account.address, org_id) + + return account.send_transaction( + self.contract.functions.changeOrganizationMetadataURI, + type_converter("bytes32")(org_id), + metadata_uri.to_bytes_uri(), + ) - def change_org_owner(self, account: Account, org_id: str, new_owner: str): ... + def change_org_owner(self, account: Account, org_id: str, new_owner: str) -> TxReceipt: + org = self.get_org(org_id) + + if account.address != org.owner: + raise UnauthorizedOrgOwnerError(account.address, org_id) + + if not is_checksum_address(new_owner): + raise IncorrectWalletAddressError(new_owner) + + return account.send_transaction( + self.contract.functions.changeOrganizationOwner, + type_converter("bytes32")(org_id), + new_owner, + ) def create_org( - self, account: Account, org_id: str, metadata_uri: str, members: Union[str, list[str], None] - ): ... + self, + account: Account, + org_id: str, + metadata_uri: FileURI, + members: Optional[list[str]] = None, + ) -> TxReceipt: + if members is None: + members = [] + + return account.send_transaction( + self.contract.functions.createOrganization, + type_converter("bytes32")(org_id), + metadata_uri.to_bytes_uri(), + members, + ) def create_service( self, account: Account, org_id: str, service_id: str, metadata_uri: FileURI @@ -98,14 +151,55 @@ def create_service( metadata_uri.to_bytes_uri(), ) - def delete_org(self, account: Account, org_id: str): ... + def delete_org(self, account: Account, org_id: str) -> TxReceipt: + org = self.get_org(org_id) - def delete_service(self, account: Account, org_id: str, service_id: str): ... + if account.address != org.owner: + raise UnauthorizedOrgOwnerError(account.address, org_id) + + return account.send_transaction( + self.contract.functions.deleteOrganization, type_converter("bytes32")(org_id) + ) + + def delete_service(self, account: Account, org_id: str, service_id: str): + org = self.get_org(org_id) + if account.address not in org.members: + raise UnauthorizedOrgMemberError(account.address, org_id) + + self.get_service(org_id, service_id) # to check if the service exists + + return account.send_transaction( + self.contract.functions.deleteServiceRegistration, + type_converter("bytes32")(org_id), + type_converter("bytes32")(service_id), + ) def remove_org_members( - self, account: Account, org_id: str, members_to_remove: Union[str, list[str]] - ): ... + self, account: Account, org_id: str, members_to_remove: list[str] + ) -> TxReceipt: + org = self.get_org(org_id) + + if account.address != org.owner: + raise UnauthorizedOrgOwnerError(account.address, org_id) + + return account.send_transaction( + self.contract.functions.deleteOrganization, + type_converter("bytes32")(org_id), + members_to_remove, + ) def update_service_metadata( - self, account: Account, org_id: str, service_id: str, metadata_uri: str - ): ... + self, account: Account, org_id: str, service_id: str, metadata_uri: FileURI + ) -> TxReceipt: + org = self.get_org(org_id) + if account.address not in org.members: + raise UnauthorizedOrgMemberError(account.address, org_id) + + self.get_service(org_id, service_id) # to check if the service exists + + return account.send_transaction( + self.contract.functions.updateServiceRegistration, + type_converter("bytes32")(org_id), + type_converter("bytes32")(service_id), + metadata_uri.to_bytes_uri(), + ) diff --git a/snet/sdk/registry/service_metadata.py b/snet/sdk/registry/service_metadata.py index 78718fe..e1ee6e9 100644 --- a/snet/sdk/registry/service_metadata.py +++ b/snet/sdk/registry/service_metadata.py @@ -1,82 +1,15 @@ -""" -Functions for manipulating service metadata +from typing import Literal, Optional -Metadata format: ----------------------------------------------------- -version - used to track format changes (current version is 1) -display_name - Display name of the service -encoding - Service encoding (proto or json) -service_type - Service type (grpc, jsonrpc or process) -service_description - Service description (arbitrary field) -payment_expiration_threshold - Service will reject payments with expiration less - than current_block + payment_expiration_threshold. - This field should be used by the client with caution. - Client should not accept arbitrary payment_expiration_threshold -model_ipfs_hash - IPFS HASH to the .tar archive of protobuf service specification -mpe_address - Address of MultiPartyEscrow contract. - Client should use it exclusively for cross-checking of mpe_address, - (because service can attack via mpe_address) - Daemon can use it directly if authenticity of metadata is confirmed -pricing {} - Pricing model - Possible pricing models: - 1. Fixed price - price_model - "fixed_price" - price_in_cogs - unique fixed price in cogs for all method (1 FET = 10^18 cogs) - (other pricing models can be easily supported) -groups [] - group is the number of endpoints which shares same payment channel; - grouping strategy is defined by service provider; - for example service provider can use region name as group name - group_name - unique name of the group (human readable) - group_id - unique id of the group (random 32 byte string in base64 encoding) - payment_address - Ethereum address to recieve payments -endpoints[] - address in the off-chain network to provide a service - group_name - endpoint - unique endpoint identifier (ip:port) - -assets {} - asset type and its ipfs value/values -""" - -import re -import json -import base64 -import secrets - -from collections import defaultdict -from enum import Enum -from typing import Literal, Any, Optional - -from pydantic import BaseModel, Field, model_validator, ValidationInfo +from pydantic import BaseModel, Field from snet.sdk.exceptions import ServiceMetadataMismatchError from snet.sdk.registry.models import FileURI -from snet.sdk.registry.organization_metadata import Payment -from snet.sdk.utils.utils import is_valid_endpoint - - -# Supported Asset types -class AssetType(Enum): - HERO_IMAGE = "hero_image" - IMAGES = "images" - DOCUMENTATION = "documentation" - TERMS_OF_USE = "terms_of_use" - - @staticmethod - def is_single_value(asset_type): - if ( - asset_type == AssetType.HERO_IMAGE.value - or asset_type == AssetType.DOCUMENTATION.value - or asset_type == AssetType.TERMS_OF_USE.value - ): - return True - - -def generate_group_id() -> str: - return base64.b64encode(secrets.token_bytes(32)).decode() +from snet.sdk.registry.organization_metadata import Payment, generate_group_id class Pricing(BaseModel): - price_model: Literal["fixed_price", "method_price"] = Field(default="fixed_price") - price_in_cogs: int = Field(ge=1, default=1) + price_model: Literal["fixed_price", "method_price"] + price_in_cogs: int = Field(ge=1) default: bool = Field(default=True) @@ -84,10 +17,10 @@ class Group(BaseModel): group_name: str = Field(min_length=1, default="default_group") group_id: str = Field(default_factory=generate_group_id, init=False) free_calls: int = Field(ge=1, default=3) - free_call_signer_address: str = Field(default="") - daemon_addresses: list[str] = Field(default=[]) - endpoints: list[str] = Field(default=[]) - pricing: list[Pricing] = Field(default=[]) + free_call_signer_address: str + daemon_addresses: list[str] = Field(default_factory=list) + endpoints: list[str] + pricing: list[Pricing] payment: Optional[Payment] = Field( default=None ) # The field from org metadata for service client functionality @@ -120,27 +53,82 @@ class ServiceMetadata(BaseModel): service_api_source: Optional[str] = Field(default=None, init=False) model_ipfs_hash: Optional[str] = Field(min_length=1, default=None, init=False, deprecated=True) mpe_address: Optional[str] = Field(default=None, init=False) - groups: list[Group] = Field(default=[]) - service_description: Optional[ServiceDescription] = Field(default=None) - media: list[Media] = Field(default=[]) - contributors: list[Contributor] = Field(default=[]) - tags: list[str] = Field(default=[]) - - @model_validator(mode="before") - @classmethod - def restrict_deprecated_fields(cls, data: Any, info: ValidationInfo) -> Any: - if not isinstance(data, dict): - return data + groups: list[Group] = Field(default_factory=list) + service_description: ServiceDescription = Field(default_factory=ServiceDescription) + media: list[Media] = Field(default_factory=list) + contributors: list[Contributor] = Field(default_factory=list) + tags: list[str] = Field(default_factory=list) + + def add_group( + self, + free_call_signer_address: str, + endpoints: list[str], + group_name: str = "default_group", + free_calls_amount: int = 3, + daemon_addresses: Optional[list[str]] = None, + price_model: Literal["fixed_price", "method_price"] = "fixed_price", + price_in_cogs: int = 1, + default: bool = True, + ) -> "ServiceMetadata": + existing_group_names = [g.group_name for g in self.groups] + if group_name in existing_group_names: + raise ServiceMetadataMismatchError( + f"Group with group_name {group_name} already exists!" + ) + if daemon_addresses is None: + daemon_addresses = [] + self.groups.append( + Group( + group_name=group_name, + free_calls=free_calls_amount, + free_call_signer_address=free_call_signer_address, + daemon_addresses=daemon_addresses, + endpoints=endpoints, + pricing=Pricing( + price_model=price_model, price_in_cogs=price_in_cogs, default=default + ), + ) + ) - is_fetching = info.context and info.context.get("from_storage") is True + return self + + def add_description( + self, + url: Optional[str] = None, + short_description: Optional[str] = None, + description: Optional[str] = None, + ) -> "ServiceMetadata": + for key, value in locals().items(): + if value is not None: + setattr(self.service_description, key, value) + + return self + + def add_media( + self, + url: str, + file_type: Literal["image", "video", "archive"], + asset_type: Literal["hero_image", "proto_file", "demo_component"], + alt_text: str = "", + ) -> "ServiceMetadata": + orders = [m.order for m in self.media] + if orders: + order = max(orders) + 1 + else: + order = 1 - if not is_fetching and data.get("model_ipfs_hash"): - raise ValueError( - "The 'model_ipfs_hash' field is deprecated and cannot be used " - "to create new metadata. Please use 'service_api_source' instead." + self.media.append( + Media( + order=order, url=url, file_type=file_type, asset_type=asset_type, alt_text=alt_text ) + ) - return data + return self + + def add_contributor(self, name: str, email: str = "") -> "ServiceMetadata": + self.contributors.append(Contributor(name=name, email_id=email)) + + return self def generate_final_json(self): if self.service_api_source is None: @@ -166,525 +154,10 @@ def generate_final_json(self): return self.model_dump_json(indent=2, exclude_none=True) + def validate_metadata(self) -> tuple[bool, str]: + try: + self.generate_final_json() + except ServiceMetadataMismatchError as e: + return False, str(e) -# TODO: we should use some standard solution here -class MPEServiceMetadata: - def __init__(self): - """init with modelIPFSHash""" - self.m = { - "version": 1, - "display_name": "", - "encoding": "grpc", # grpc by default - "service_type": "grpc", # grpc by default - # one week by default (15 sec block, 24*60*60*7/15) - "model_ipfs_hash": "", - "mpe_address": "", - "groups": [], - "assets": {}, - "media": [], - "tags": [], - } - - def set_simple_field(self, f, v): - if ( - f != "display_name" - and f != "encoding" - and f != "model_ipfs_hash" - and f != "mpe_address" - and f != "service_type" - and f != "payment_expiration_threshold" - and f != "service_description" - ): - raise Exception("unknown field in MPEServiceMetadata") - self.m[f] = v - - def set_fixed_price_in_cogs(self, group_name, price): - if not isinstance(price, int): - raise Exception("Price should have int type") - - if not self.is_group_name_exists(group_name): - raise Exception("the group %s is not present" % str(group_name)) - - for group in self.m["groups"]: - if group["group_name"] == group_name: - is_fixed_price_enabled = False - # default=True it will change when we will go live with method level pricing - if "pricing" in group: - for pricing in group["pricing"]: - if pricing["price_model"] == "fixed_price": - is_fixed_price_enabled = True - pricing["price_in_cogs"] = price - if not is_fixed_price_enabled: - group["pricing"].append( - { - "price_model": "fixed_price", - "price_in_cogs": price, - "default": True, - } - ) - else: - group["pricing"] = [ - { - "price_model": "fixed_price", - "price_in_cogs": price, - "default": True, - } - ] - - def set_method_price_in_cogs(self, group_name, package_name, service_name, method, price): - if not isinstance(price, int): - raise Exception("Price should have int type") - - if not self.is_group_name_exists(group_name): - raise Exception("the group %s is not present" % str(group_name)) - - groups = self.m["groups"] - for group in groups: - if group["group_name"] == group_name: - service_name = service_name - package_name = package_name - method_pricing = {"method_name": method, "price_in_cogs": price} - pricings = [] - - if "pricings" in group: - pricings = group["pricings"] - - fixed_price_method_model_exist = False - for pricing in pricings: - if pricing["price_model"] == "fixed_price_per_method": - fixed_price_method_model_exist = True - - if "details" in pricing: - fixed_price_method_pricing_for_service_exist = False - for detail in pricing["details"]: - if detail["service_name"] == service_name: - # adding new method pricing for existing service - fixed_price_method_pricing_for_service_exist = True - detail["method_pricing"].append(method_pricing) - - if not fixed_price_method_pricing_for_service_exist: - # pricing for new method for new service - pricing["details"].append( - { - "service_name": service_name, - "method_pricing": [method_pricing], - } - ) - else: - pricing["details"] = [ - { - "service_name": service_name, - "method_pricing": [method_pricing], - } - ] - - if not fixed_price_method_model_exist: - fixed_price_per_method = { - "package_name": package_name, - "price_model": "fixed_price_per_method", - "details": [ - { - "service_name": service_name, - "method_pricing": [method_pricing], - } - ], - } - group["pricings"] = [fixed_price_per_method] - - def add_group(self, group_name): - """Return new group_id in base64""" - if self.is_group_name_exists(group_name): - raise Exception('the group "%s" is already present' % str(group_name)) - - self.m["groups"] += [{"group_name": group_name}] - - def remove_group(self, group_name): - for group in self.m["groups"]: - if group["group_name"] == group_name: - self.m["groups"].remove(group) - - def get_tags(self): - tags = [] - if "tags" in self.m: - tags = self.m["tags"] - return tags - - def add_tag(self, tag_name): - if "tags" not in self.m: - self.m["tags"] = [] - - if tag_name in self.m["tags"]: - print(f"The tag {str(tag_name)} is already present") - return - self.m["tags"] += [tag_name] - - def remove_tag(self, tag_name): - if "tags" not in self.m: - self.m["tags"] = [] - - if tag_name not in self.m["tags"]: - print(f"The tag {str(tag_name)} is not found") - return - self.m["tags"].remove(tag_name) - - def add_asset(self, asset_ipfs_hash, asset_type): - # Check if we need to validation if same asset type is added twice if we need to add it or replace the existing one - - if "assets" not in self.m: - self.m["assets"] = {} - - # hero image will contain the single value - if AssetType.is_single_value(asset_type): - self.m["assets"][asset_type] = asset_ipfs_hash - - # images can contain multiple value - elif asset_type == AssetType.IMAGES.value: - if asset_type in self.m["assets"]: - self.m["assets"][asset_type].append(asset_ipfs_hash) - else: - self.m["assets"][asset_type] = [asset_ipfs_hash] - else: - raise Exception("Invalid asset type %s" % asset_type) - - def remove_all_assets(self): - self.m["assets"] = {} - - def remove_asset(self, asset_type): - if "assets" in self.m: - if AssetType.is_single_value(asset_type): - self.m["assets"][asset_type] = "" - elif asset_type == AssetType.IMAGES.value: - self.m["assets"][asset_type] = [] - else: - raise Exception("Invalid asset type %s" % asset_type) - - def add_endpoint_to_group(self, group_name, endpoint): - if re.match("^\w+://", endpoint) is None: - # TODO: Default to https when our tutorials show setting up a ssl certificate as well - endpoint = "http://" + endpoint - if not is_valid_endpoint(endpoint): - raise Exception("Endpoint is not a valid URL") - if not self.is_group_name_exists(group_name): - raise Exception("the group %s is not present" % str(group_name)) - if endpoint in self.get_all_endpoints_for_group(group_name): - raise Exception("the endpoint %s is already present" % str(endpoint)) - - groups = self.m["groups"] - for group in groups: - if group["group_name"] == group_name: - if "endpoints" in group: - group["endpoints"].append(endpoint) - else: - group["endpoints"] = [endpoint] - - def remove_all_endpoints_for_group(self, group_name): - if not self.is_group_name_exists(group_name): - raise Exception("Group name does not exist %s", group_name) - - groups = self.m["groups"] - for group in groups: - if group["group_name"] == group_name: - group["endpoints"] = [] - - def is_group_name_exists(self, group_name): - """check if group with given name is already exists""" - groups = self.m["groups"] - for g in groups: - if g["group_name"] == group_name: - return True - return False - - def get_group_by_group_id(self, group_id): - """return group with given group_id (return None if it doesn't exist)""" - group_id_base64 = base64.b64encode(group_id).decode("ascii") - groups = self.m["groups"] - for g in groups: - if g["group_id"] == group_id_base64: - return g - return None - - def set_free_calls_for_group(self, group_name, free_calls): - groups = self.m["groups"] - for g in groups: - if g["group_name"] == group_name: - g["free_calls"] = free_calls - - def set_freecall_signer_address(self, group_name, signer_address): - groups = self.m["groups"] - for g in groups: - if g["group_name"] == group_name: - g["free_call_signer_address"] = signer_address - - def get_json(self): - return json.dumps(self.m) - - def get_json_pretty(self): - return json.dumps(self.m, indent=4) - - def set_from_json(self, j): - # TODO: we probaly should check the consistensy of loaded json here - # check that it contains required fields - self.m = json.loads(j) - if "tags" not in self.m: - self.m["tags"] = [] - - def load(self, file_name): - with open(file_name) as f: - self.set_from_json(f.read()) - - def save_pretty(self, file_name): - with open(file_name, "w") as f: - f.write(self.get_json_pretty()) - - def __getitem__(self, key): - return self.m[key] - - def __contains__(self, key): - return key in self.m - - def get(self, key, default=None): - return self.m.get(key, default) - - def get_group_name_nonetrick(self, group_name=None): - """In all getter function in case of single payment group, group_name can be None""" - groups = self.m["groups"] - if len(groups) == 0: - raise Exception("Cannot find any groups in metadata") - if not group_name: - if len(groups) > 1: - raise Exception( - "We have more than one payment group in metadata, so group_name should be specified" - ) - return groups[0]["group_name"] - return group_name - - def get_group(self, group_name=None): - group_name = self.get_group_name_nonetrick(group_name) - for g in self.m["groups"]: - if g["group_name"] == group_name: - return g - raise Exception('Cannot find group "%s" in metadata' % group_name) - - def get_group_id_base64(self, group_name=None): - return self.get_group(group_name)["group_id"] - - def get_group_id(self, group_name=None): - return base64.b64decode(self.get_group_id_base64(group_name)) - - def get_payment_address(self, group_name=None): - return self.get_group(group_name)["payment_address"] - - def add_daemon_address_to_group(self, group_name, daemon_address): - groups = self.m["groups"] - if not self.is_group_name_exists(group_name): - raise Exception('Cannot find group "%s" in metadata' % group_name) - for group in groups: - if group["group_name"] == group_name: - if "daemon_addresses" in group: - group["daemon_addresses"].append(daemon_address) - else: - group["daemon_addresses"] = [daemon_address] - - def remove_all_daemon_addresses_for_group(self, group_name): - groups = self.m["groups"] - if not self.is_group_name_exists(group_name): - raise Exception('Cannot find group "%s" in metadata' % group_name) - for group in groups: - if group["group_name"] == group_name: - group["daemon_addresses"] = [] - - def get_all_endpoints_for_group(self, group_name): - for group in self.m["groups"]: - if group["group_name"] == group_name: - if "endpoints" in group: - return group["endpoints"] - return [] - - def get_all_group_endpoints(self): - group_endpoints = {} - for group in self.m["groups"]: - if "endpoints" in group: - group_endpoints[group["group_name"]] = group["endpoints"] - return group_endpoints - - def get_all_endpoints_with_group_name(self): - endpts_with_grp = defaultdict(list) - for e in self.m["endpoints"]: - endpts_with_grp[e["group_name"]].append(e["endpoint"]) - return endpts_with_grp - - def get_endpoints_for_group(self, group_name=None): - group_name = self.get_group_name_nonetrick(group_name) - return [e["endpoint"] for e in self.m["endpoints"] if e["group_name"] == group_name] - - def add_contributor(self, name, email_id): - if "contributors" in self.m: - contributors = self.m["contributors"] - else: - contributors = [] - - contributors.append({"name": name, "email_id": email_id}) - self.m["contributors"] = contributors - - def remove_contributor_by_email(self, email_id): - self.m["contributors"] = [ - contributor - for contributor in self.m["contributors"] - if contributor["email_id"] != email_id - ] - - def group_init(self, group_name): - """Required values for creating a new payment group. - - Args: - group_name: If org contains only 1 payment group -> default_group, ask user for other groups otherwise. - - Raises: - ValueError: User enters non-integer value for `fixed_price.` - Exception: User enters same endpoints. - """ - self.add_group(group_name) - while True: - try: - fixed_price = int(input("Set fixed price: ")) - except ValueError: - print("Enter a valid integer.") - else: - self.set_fixed_price_in_cogs(group_name, fixed_price) - break - while True: - try: - endpoints = input("Add endpoints as comma separated values: ").split(",") - if endpoints[0] == "": - print("Endpoints required.") - else: - for endpoint in endpoints: - self.add_endpoint_to_group(group_name, endpoint.strip()) - break - except Exception as e: - print(e) - while True: - daemon_addresses = input("Add daemon addresses as comma separated values: ").split(",") - if daemon_addresses[0] == "": - print("Daemon address required.") - else: - for daemon_address in daemon_addresses: - self.add_daemon_address_to_group(group_name, daemon_address.strip()) - break - if input("Free calls included? [y/n] ").lower() == "y": - self.set_free_calls_for_group(group_name, int(input("free calls: (15) ") or 15)) - self.set_freecall_signer_address(group_name, input("free call signer address: ")) - - def add_media(self, url, media_type, hero_img=False): - """Add new individual media to service metadata.""" - if "media" not in self.m: - self.m["media"] = [] - individual_media = {} - if hero_img: - assert media_type == "image", f"{media_type.upper()} media-type cannot be a hero-image." - assert not self._is_asset_type_exists(), ( - "Hero-image already exists (only 1 unique hero-image allowed.)" - ) - individual_media["asset_type"] = ( - AssetType.HERO_IMAGE.value - ) # Dependency with AssetType, fix if obsolete - if len(self.m["media"]) == 0: - individual_media["order"] = 1 - else: - individual_media["order"] = self.m["media"][-1]["order"] + 1 - individual_media["url"] = url - individual_media["file_type"] = media_type - if media_type == "image": - individual_media["alt_text"] = "hover_on_the_image_text" - else: - individual_media["alt_text"] = "hover_on_the_video_url" - self.m["media"].append(individual_media) - - def remove_media(self, order): - """Remove individual media from service metadata using unique order key.""" - assert len(self.m["media"]) > 0, "No media content to remove." - assert order > 0, "Order of individual media starts from 1." - del_position = -1 - for i in range(len(self.m["media"])): - if order == self.m["media"][i]["order"]: - del self.m["media"][i] - del_position = i - break - if del_position == -1: - raise Exception(f"Media with order: {order} not found.") - for i in range(del_position, len(self.m["media"])): - self.m["media"][i]["order"] -= 1 - - def remove_all_media(self): - """Remove all individual media from metadata.""" - self.m["media"].clear() - - def swap_media_order(self, move_from, move_to): - """Swap orders of two different media given their individual orders (move_from, move_to).""" - assert len(self.m["media"]) + 1 > move_from > 0, f"Order {move_from} out of bounds." - assert len(self.m["media"]) + 1 > move_to > 0, f"Order {move_to} out of bounds." - self.m["media"][move_to - 1], self.m["media"][move_from - 1] = ( - self.m["media"][move_from - 1], - self.m["media"][move_to - 1], - ) - ( - self.m["media"][move_to - 1]["order"], - self.m["media"][move_from - 1]["order"], - ) = ( - self.m["media"][move_from - 1]["order"], - self.m["media"][move_to - 1]["order"], - ) - - def change_media_order(self): - """Mini REPL to change order of all individual media""" - order_range = range(1, len(self.m["media"]) + 1) - available_orders = list(order_range) - for individual_media in self.m["media"]: - print( - f"File Type: {individual_media['file_type']}, Current Order: {individual_media['order']}" - ) - while True: - try: - new_order = int(input(f"Enter new order for {individual_media['url']}: ")) - except ValueError: - print("Error: Order entered not a number. Try again.") - else: - if new_order in available_orders: - individual_media["order"] = new_order - available_orders.remove(new_order) - break - elif new_order not in order_range: - print( - f"Media array contains only {len(self.m['media'])} items. Enter order between [{order_range.start}, {order_range.stop - 1}]" - ) - else: - print(f"Order already taken. Available orders: {available_orders}") - self.m["media"].sort(key=lambda x: x["order"]) - - def _is_asset_type_exists(self): - """Return boolean on whether asset type already exists""" - media = self.m["media"] - for individual_media in media: - if "asset_type" in individual_media: - return True - return False - - def add_description(self): - if "service_description" not in self.m: - self.m["service_description"] = { - "url": input("user guide url: "), - "long_description": input("service long description: "), - "short_description": input("service short description: "), - } - - -def load_mpe_service_metadata(f) -> MPEServiceMetadata: - metadata = MPEServiceMetadata() - metadata.load(f) - return metadata - - -def mpe_service_metadata_from_json(j) -> MPEServiceMetadata: - metadata = MPEServiceMetadata() - metadata.set_from_json(j) - return metadata + return True, "" diff --git a/snet/sdk/registry/storage_provider.py b/snet/sdk/registry/storage_provider.py index af71ecd..85f11cf 100644 --- a/snet/sdk/registry/storage_provider.py +++ b/snet/sdk/registry/storage_provider.py @@ -15,6 +15,9 @@ LighthouseError, WrongDirectoryError, ProtoFilesNotFoundError, + IPFSHashMismatchError, + IPFSHashCheckError, + ExtractingProtoError, ) from snet.sdk.registry.organization_metadata import OrganizationMetadata from snet.sdk.registry.models import StorageType, FileURI @@ -26,9 +29,8 @@ class StorageProvider(object): def __init__(self): self._ipfs_client = ipfshttpclient.connect(config.IPFS_ENDPOINT) self._lighthouse_client = Lighthouse(config.LIGHTHOUSE_TOKEN) - self._ipfs_client.add() - def fetch_org_metadata(self, metadata_uri: FileURI): + def fetch_org_metadata(self, metadata_uri: FileURI) -> OrganizationMetadata: org_metadata_json = self._get_from_storage(metadata_uri) raw_org_metadata = json.loads(org_metadata_json) org_metadata = OrganizationMetadata(**raw_org_metadata) @@ -42,9 +44,8 @@ def fetch_service_metadata(self, metadata_uri: FileURI) -> ServiceMetadata: return service_metadata - def fetch_and_extract_proto(self, service_api_source, proto_dir): - tar_uri = FileURI.from_raw_uri(service_api_source) - spec_tar = self._get_from_storage(tar_uri) + def fetch_and_extract_proto(self, service_api_source: FileURI, proto_dir) -> None: + spec_tar = self._get_from_storage(service_api_source, decode=False) self._safe_extract_proto(spec_tar, proto_dir) def publish_organization_metadata( @@ -144,10 +145,10 @@ def _get_from_ipfs_and_checkhash(self, ipfs_hash: str, validate: bool = False) - actual_digest = h.digest() if actual_digest != expected_digest: - raise Exception("IPFS hash mismatch with data") + raise IPFSHashMismatchError() except Exception as e: - raise ValueError(f"Integrity check failed: {str(e)}") from e + raise IPFSHashCheckError() from e return data @@ -161,15 +162,15 @@ def _safe_extract_proto(spec_tar: bytes, proto_dir: Union[str, Path]) -> None: for m in f.getmembers(): if not m.isfile(): - raise ValueError( + raise ExtractingProtoError( f"Security/Format Error: Tarball contains a non-file item: '{m.name}'" ) if Path(m.name).parent != Path("."): - raise ValueError( + raise ExtractingProtoError( f"Format Error: Tarball contains nested paths ('{m.name}'). Only flat archives are supported." ) if not m.name.endswith(".proto"): - raise ValueError( + raise ExtractingProtoError( f"Format Error: Unexpected file type '{m.name}'. Only .proto files allowed." ) target_file = dest_dir / m.name diff --git a/snet/sdk/service_client.py b/snet/sdk/service_client.py index 30ac854..2078545 100644 --- a/snet/sdk/service_client.py +++ b/snet/sdk/service_client.py @@ -21,7 +21,7 @@ PrePaidPaymentStrategy, ) from snet.sdk.resources.root_certificate import certificate -from snet.sdk.registry.service_metadata import MPEServiceMetadata +from snet.sdk.registry.service_metadata import Group from snet.sdk.types import ModuleName, ServiceStub from snet.sdk.utils.utils import ( RESOURCES_PATH, @@ -29,7 +29,7 @@ find_file_by_keyword, ) from snet.sdk.training.training import Training -from snet.sdk.training.exceptions import NoTrainingException +from snet.sdk.exceptions import NoTrainingError from snet.sdk.utils.call_utils import create_intercept_call_func @@ -38,8 +38,7 @@ def __init__( self, org_id: str, service_id: str, - service_metadata: MPEServiceMetadata, - group: dict, + group: Group, service_stubs: list[ServiceStub], payment_strategy, options: dict, @@ -53,7 +52,6 @@ def __init__( ): self.org_id = org_id self.service_id = service_id - self.service_metadata = service_metadata self.group = group self.payment_strategy = payment_strategy if isinstance(payment_strategy, PrePaidPaymentStrategy): @@ -69,7 +67,7 @@ def __init__( self.payment_channel_provider = payment_channel_provider self.path_to_pb_files = path_to_pb_files - self.expiry_threshold: int = self.group["payment"]["payment_expiration_threshold"] + self.expiry_threshold: int = self.group.payment.payment_expiration_threshold self.__base_grpc_channel = self._get_grpc_channel() _intercept_call_func = create_intercept_call_func( self.payment_strategy.get_payment_metadata, self @@ -120,9 +118,7 @@ def get_grpc_base_channel(self) -> grpc.Channel: def _get_grpc_channel(self) -> grpc.Channel: endpoint = self.options.get("endpoint", None) if endpoint is None: - endpoint = self.service_metadata.get_all_endpoints_for_group(self.group["group_name"])[ - 0 - ] + endpoint = self.group.endpoints[0] endpoint_object = urlparse(endpoint) if endpoint_object.port is not None: channel_endpoint = endpoint_object.hostname + ":" + str(endpoint_object.port) @@ -161,8 +157,8 @@ def _filter_existing_channels_from_new_payment_channels( def load_open_channels(self) -> list[PaymentChannel]: current_block_number = self.sdk_web3.eth.block_number - payment_address = self.group["payment"]["payment_address"] - group_id = base64.b64decode(str(self.group["group_id"])) + payment_address = self.group.payment.payment_address + group_id = base64.b64decode(str(self.group.group_id)) new_payment_channels = self.payment_channel_provider.get_past_open_channels( self.account, payment_address, @@ -185,7 +181,7 @@ def update_channel_states(self) -> list[PaymentChannel]: return self.payment_channels def default_channel_expiration(self) -> int: - current_block_number = self.sdk_web3.eth.get_block("latest").number + current_block_number = self.sdk_web3.eth.block_number return current_block_number + self.expiry_threshold def _generate_payment_channel_state_service_client(self) -> Any: @@ -198,8 +194,8 @@ def get_mpe_balance(self): return self.mpe_contract.balance(self.account) def open_channel(self, amount: int, expiration: int) -> PaymentChannel: - payment_address = self.group["payment"]["payment_address"] - group_id = base64.b64decode(str(self.group["group_id"])) + payment_address = self.group.payment.payment_address + group_id = base64.b64decode(str(self.group.group_id)) return self.payment_channel_provider.open_channel( self.account, amount, @@ -210,8 +206,8 @@ def open_channel(self, amount: int, expiration: int) -> PaymentChannel: ) def deposit_and_open_channel(self, amount: int, expiration: int) -> PaymentChannel: - payment_address = self.group["payment"]["payment_address"] - group_id = base64.b64decode(str(self.group["group_id"])) + payment_address = self.group.payment.payment_address + group_id = base64.b64decode(str(self.group.group_id)) return self.payment_channel_provider.deposit_and_open_channel( self.account, amount, @@ -222,7 +218,7 @@ def deposit_and_open_channel(self, amount: int, expiration: int) -> PaymentChann ) def get_price(self) -> int: - return self.group["pricing"][0]["price_in_cogs"] + return self.group.pricing[0].price_in_cogs def generate_signature(self, message: bytes) -> bytes: return bytes( @@ -247,13 +243,13 @@ def get_service_details(self) -> tuple[str, str, str, str]: self.org_id, self.service_id, self.group["group_id"], - self.service_metadata.get_all_endpoints_for_group(self.group["group_name"])[0], + self.group.endpoints[0], ) @property def training(self) -> Training: if not self.__training.is_enabled: - raise NoTrainingException(self.org_id, self.service_id) + raise NoTrainingError(self.org_id, self.service_id) return self.__training def _get_training_model_id(self, model_id: str) -> Any: diff --git a/snet/sdk/training/exceptions.py b/snet/sdk/training/exceptions.py deleted file mode 100644 index f1965bb..0000000 --- a/snet/sdk/training/exceptions.py +++ /dev/null @@ -1,32 +0,0 @@ -from grpc import RpcError - - -class WrongDatasetException(Exception): - def __init__(self, errors: list[str]): - self.errors = errors - exception_msg = "Dataset check failed:\n" - for check in errors: - exception_msg += f"\t{check}\n" - super().__init__(exception_msg) - - -class WrongMethodException(Exception): - def __init__(self, method_name: str): - super().__init__(f"Method with name {method_name} not found!") - - -class NoTrainingException(Exception): - def __init__(self, org_id: str, service_id: str): - super().__init__( - f"Training is not implemented for the service with org_id={org_id} and service_id={service_id}!" - ) - - -class GRPCException(RpcError): - def __init__(self, error: RpcError): - super().__init__(f"An error occurred during the grpc call: {error}.") - - -class NoSuchModelException(Exception): - def __init__(self, model_id: str): - super().__init__(f"Model with id {model_id} not found!") diff --git a/snet/sdk/training/training.py b/snet/sdk/training/training.py index b37bedb..fc04c34 100644 --- a/snet/sdk/training/training.py +++ b/snet/sdk/training/training.py @@ -11,11 +11,11 @@ ) from snet.sdk.utils.call_utils import create_intercept_call_func from snet.sdk.utils.utils import add_to_path, RESOURCES_PATH -from snet.sdk.training.exceptions import ( - WrongDatasetException, - WrongMethodException, - GRPCException, - NoSuchModelException, +from snet.sdk.exceptions import ( + WrongDatasetError, + WrongMethodError, + GRPCError, + NoSuchModelError, ) from snet.sdk.training.responses import ( ModelStatus, @@ -85,9 +85,9 @@ def validate_model_price(self, model_id: str) -> int: "validate_model_price", request_data=validate_model_price_request, ) - except GRPCException as e: + except GRPCError as e: if "unable to access model" in str(e): - raise NoSuchModelException(model_id) + raise NoSuchModelError(model_id) else: raise e @@ -100,9 +100,9 @@ def train_model_price(self, model_id: str) -> int: ) try: response = self._call_method("train_model_price", request_data=common_request) - except GRPCException as e: + except GRPCError as e: if "unable to access model" in str(e): - raise NoSuchModelException(model_id) + raise NoSuchModelError(model_id) else: raise e @@ -115,9 +115,9 @@ def delete_model(self, model_id: str) -> ModelStatus: ) try: response = self._call_method("delete_model", request_data=common_request) - except GRPCException as e: + except GRPCError as e: if "unable to access model" in str(e): - raise NoSuchModelException(model_id) + raise NoSuchModelError(model_id) else: raise e @@ -172,9 +172,9 @@ def get_model(self, model_id: str) -> Model: ) try: response = self._call_method("get_model", request_data=common_request) - except GRPCException as e: + except GRPCError as e: if "unable to access model" in str(e): - raise NoSuchModelException(model_id) + raise NoSuchModelError(model_id) else: raise e model = Model(response) @@ -226,9 +226,9 @@ def update_model( try: response = self._call_method("update_model", request_data=update_model_request) - except GRPCException as e: + except GRPCError as e: if "unable to access model" in str(e): - raise NoSuchModelException(model_id) + raise NoSuchModelError(model_id) else: raise e @@ -281,9 +281,9 @@ def request_iter(file): response = self._call_method( "upload_and_validate", request_data=request_iter(f), paid=True ) - except GRPCException as e: + except GRPCError as e: if "unable to access model" in str(e): - raise NoSuchModelException(model_id) + raise NoSuchModelError(model_id) else: raise e finally: @@ -301,9 +301,9 @@ def train_model(self, model_id: str, price: int) -> ModelStatus: try: response = self._call_method("train_model", request_data=common_request, paid=True) - except GRPCException as e: + except GRPCError as e: if "unable to access model" in str(e): - raise NoSuchModelException(model_id) + raise NoSuchModelError(model_id) else: raise e @@ -317,7 +317,7 @@ def _call_method(self, method_name: str, request_data, paid=False) -> Any: response = getattr(stub, method_name)(request_data) return response except grpc.RpcError as e: - raise GRPCException(e) + raise GRPCError(e) def _get_training_stub(self, paid=False) -> Any: grpc_channel = self.service_client.get_grpc_base_channel() @@ -345,12 +345,12 @@ def _check_method_name(self, method_name: str) -> tuple[str, str]: for method in methods: if method[0] == method_name: return service, method[0] - raise WrongMethodException(method_name) + raise WrongMethodError(method_name) def _check_training(self) -> bool: try: service_methods = self.get_training_metadata().training_methods - except GRPCException: + except GRPCError: return False if len(service_methods.keys()) == 0: return False @@ -402,7 +402,7 @@ def _check_dataset(self, model_id: str, zip_path: str | Path | PurePath) -> None ) if len(failed_checks) > 0: - raise WrongDatasetException(failed_checks) + raise WrongDatasetError(failed_checks) def _get_grpc_channel(self, base_channel: grpc.Channel) -> grpc.Channel: intercept_call_func = create_intercept_call_func(