Skip to content

Commit ef4de59

Browse files
authored
Merge pull request #95 from singnet/freecalls-update
Freecalls update
2 parents 4e6879a + 75c2395 commit ef4de59

11 files changed

Lines changed: 187 additions & 148 deletions

File tree

snet/sdk/__init__.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import sys
44
import warnings
5+
from enum import Enum
56

67
import google.protobuf.internal.api_implementation
78

@@ -26,18 +27,28 @@
2627
from snet.sdk.client_lib_generator import ClientLibGenerator
2728
from snet.sdk.mpe.mpe_contract import MPEContract
2829
from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider
29-
from snet.sdk.payment_strategies.default_payment_strategy import DefaultPaymentStrategy
30+
from snet.sdk.payment_strategies.default_payment_strategy import *
3031
from snet.sdk.service_client import ServiceClient
3132
from snet.sdk.storage_provider.storage_provider import StorageProvider
3233
from snet.sdk.custom_typing import ModuleName, ServiceStub
33-
from snet.sdk.utils.utils import (bytes32_to_str, find_file_by_keyword,
34-
type_converter)
34+
from snet.sdk.utils.utils import (
35+
bytes32_to_str,
36+
find_file_by_keyword,
37+
type_converter
38+
)
3539

3640
google.protobuf.internal.api_implementation.Type = lambda: 'python'
3741
_sym_db = _symbol_database.Default()
3842
_sym_db.RegisterMessage = lambda x: None
3943

4044

45+
class PaymentStrategyType(Enum):
46+
PAID_CALL = PaidCallPaymentStrategy
47+
FREE_CALL = FreeCallPaymentStrategy
48+
PREPAID_CALL = PrePaidPaymentStrategy
49+
DEFAULT = DefaultPaymentStrategy
50+
51+
4152
class SnetSDK:
4253
"""Base Snet SDK"""
4354

@@ -91,8 +102,9 @@ def __init__(self, sdk_config: Config, metadata_provider=None):
91102
def create_service_client(self,
92103
org_id: str,
93104
service_id: str,
94-
group_name=None,
95-
payment_strategy=None,
105+
group_name: str=None,
106+
payment_strategy: PaymentStrategy = None,
107+
payment_strategy_type: PaymentStrategyType=PaymentStrategyType.DEFAULT,
96108
address=None,
97109
options=None,
98110
concurrent_calls: int = 1):
@@ -118,15 +130,14 @@ def create_service_client(self,
118130
print("Generating client library...")
119131
self.lib_generator.generate_client_library()
120132

121-
if payment_strategy is None:
122-
payment_strategy = DefaultPaymentStrategy(
123-
concurrent_calls=concurrent_calls
124-
)
125-
126133
if options is None:
127134
options = dict()
128135
options['user_address'] = address if address else ""
129136
options['concurrency'] = self._sdk_config.get("concurrency", True)
137+
options['concurrent_calls'] = concurrent_calls
138+
139+
if payment_strategy is None:
140+
payment_strategy = payment_strategy_type.value()
130141

131142
service_metadata = self._metadata_provider.enhance_service_metadata(
132143
org_id, service_id
@@ -137,7 +148,8 @@ def create_service_client(self,
137148

138149
pb2_module = self.get_module_by_keyword(keyword="pb2.py")
139150
_service_client = ServiceClient(org_id, service_id, service_metadata,
140-
group, service_stubs, payment_strategy,
151+
group, service_stubs,
152+
payment_strategy,
141153
options, self.mpe_contract,
142154
self.account, self.web3, pb2_module,
143155
self.payment_channel_provider,

snet/sdk/concurrency_manager.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
import grpc
44
import web3
55

6-
from snet.sdk.service_client import ServiceClient
76
from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path
87

98

109
class ConcurrencyManager:
11-
def __init__(self, concurrent_calls: int):
10+
def __init__(self, concurrent_calls: int=1):
1211
self.__concurrent_calls: int = concurrent_calls
1312
self.__token: str = ''
1413
self.__planned_amount: int = 0
@@ -18,14 +17,18 @@ def __init__(self, concurrent_calls: int):
1817
def concurrent_calls(self) -> int:
1918
return self.__concurrent_calls
2019

20+
@concurrent_calls.setter
21+
def concurrent_calls(self, concurrent_calls: int):
22+
self.__concurrent_calls = concurrent_calls
23+
2124
def get_token(self, service_client, channel, service_call_price):
2225
if len(self.__token) == 0:
2326
self.__token = self.__get_token(service_client, channel, service_call_price)
2427
elif self.__used_amount >= self.__planned_amount:
2528
self.__token = self.__get_token(service_client, channel, service_call_price, new_token=True)
2629
return self.__token
2730

28-
def __get_token(self, service_client: ServiceClient, channel, service_call_price, new_token=False):
31+
def __get_token(self, service_client, channel, service_call_price, new_token=False):
2932
if not new_token:
3033
amount = channel.state["last_signed_amount"]
3134
if amount != 0:
@@ -47,13 +50,13 @@ def __get_token(self, service_client: ServiceClient, channel, service_call_price
4750
self.__planned_amount = token_reply.planned_amount
4851
return token_reply.token
4952

50-
def __get_stub_for_get_token(self, service_client: ServiceClient):
53+
def __get_stub_for_get_token(self, service_client):
5154
grpc_channel = service_client.get_grpc_base_channel()
5255
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
5356
token_service_pb2_grpc = importlib.import_module("token_service_pb2_grpc")
5457
return token_service_pb2_grpc.TokenServiceStub(grpc_channel)
5558

56-
def __get_token_for_amount(self, service_client: ServiceClient, channel, amount):
59+
def __get_token_for_amount(self, service_client, channel, amount):
5760
nonce = channel.state["nonce"]
5861
stub = self.__get_stub_for_get_token(service_client)
5962
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from snet.sdk.concurrency_manager import ConcurrencyManager
21
from snet.sdk.payment_strategies.freecall_payment_strategy import FreeCallPaymentStrategy
32
from snet.sdk.payment_strategies.paidcall_payment_strategy import PaidCallPaymentStrategy
43
from snet.sdk.payment_strategies.prepaid_payment_strategy import PrePaidPaymentStrategy
@@ -7,26 +6,22 @@
76

87
class DefaultPaymentStrategy(PaymentStrategy):
98

10-
def __init__(self, concurrent_calls: int = 1):
11-
self.concurrent_calls = concurrent_calls
12-
self.concurrencyManager = ConcurrencyManager(concurrent_calls)
9+
def __init__(self):
1310
self.channel = None
1411

15-
def set_concurrency_token(self, token):
16-
self.concurrencyManager.__token = token
17-
1812
def set_channel(self, channel):
1913
self.channel = channel
2014

2115
def get_payment_metadata(self, service_client):
2216
free_call_payment_strategy = FreeCallPaymentStrategy()
2317

24-
if free_call_payment_strategy.is_free_call_available(service_client):
18+
if free_call_payment_strategy.get_free_calls_available(service_client) > 0:
2519
metadata = free_call_payment_strategy.get_payment_metadata(service_client)
2620
else:
2721
if service_client.get_concurrency_flag():
28-
payment_strategy = PrePaidPaymentStrategy(self.concurrencyManager)
29-
metadata = payment_strategy.get_payment_metadata(service_client, self.channel)
22+
concurrent_calls = service_client.get_concurrent_calls()
23+
payment_strategy = PrePaidPaymentStrategy(concurrent_calls)
24+
metadata = payment_strategy.get_payment_metadata(service_client)
3025
else:
3126
payment_strategy = PaidCallPaymentStrategy()
3227
metadata = payment_strategy.get_payment_metadata(service_client)
@@ -37,5 +32,6 @@ def get_price(self, service_client):
3732
pass
3833

3934
def get_concurrency_token_and_channel(self, service_client):
40-
payment_strategy = PrePaidPaymentStrategy(self.concurrencyManager)
35+
concurrent_calls = service_client.get_concurrent_calls()
36+
payment_strategy = PrePaidPaymentStrategy(concurrent_calls)
4137
return payment_strategy.get_concurrency_token_and_channel(service_client)

snet/sdk/payment_strategies/freecall_payment_strategy.py

Lines changed: 52 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,99 @@
11
import importlib
2-
from urllib.parse import urlparse
32

43
import grpc
5-
from grpc import _channel
64
import web3
75

86
from snet.sdk.payment_strategies.payment_strategy import PaymentStrategy
9-
from snet.sdk.resources.root_certificate import certificate
107
from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path
118

129
class FreeCallPaymentStrategy(PaymentStrategy):
1310

14-
def is_free_call_available(self, service_client) -> bool:
15-
try:
16-
self._user_address = service_client.options["user_address"]
17-
self._free_call_token, self._token_expiry_date_block = self.get_free_call_token_details(service_client)
11+
def __init__(self):
12+
self._user_address = None
13+
self._free_call_token = None
14+
self._token_expiration_block = None
1815

19-
if not self._free_call_token:
20-
return False
16+
def get_free_calls_available(self, service_client) -> int:
17+
if not self._user_address:
18+
self._user_address = service_client.account.signer_address
2119

22-
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
23-
state_service_pb2 = importlib.import_module("state_service_pb2")
20+
current_block_number = service_client.get_current_block_number()
2421

25-
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
26-
state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc")
22+
if (not self._free_call_token or
23+
not self._token_expiration_block or
24+
current_block_number > self._token_expiration_block):
25+
self._free_call_token, self._token_expiration_block = self.get_free_call_token_details(service_client)
2726

28-
signature, current_block_number = self.generate_signature(service_client)
27+
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
28+
state_service_pb2 = importlib.import_module("state_service_pb2")
2929

30-
request = state_service_pb2.FreeCallStateRequest()
31-
request.user_address = self._user_address
32-
request.token_for_free_call = self._free_call_token
33-
request.token_expiry_date_block = self._token_expiry_date_block
34-
request.signature = signature
35-
request.current_block = current_block_number
30+
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
31+
state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc")
3632

37-
channel = self.select_channel(service_client)
33+
signature, _ = self.generate_signature(service_client, current_block_number)
34+
request = state_service_pb2.FreeCallStateRequest(
35+
address=self._user_address,
36+
free_call_token=self._free_call_token,
37+
signature=signature,
38+
current_block=current_block_number
39+
)
3840

39-
stub = state_service_pb2_grpc.FreeCallStateServiceStub(channel)
41+
channel = service_client.get_grpc_base_channel()
42+
stub = state_service_pb2_grpc.FreeCallStateServiceStub(channel)
43+
44+
try:
4045
response = stub.GetFreeCallsAvailable(request)
41-
if response.free_calls_available > 0:
42-
return True
43-
return False
46+
return response.free_calls_available
4447
except grpc.RpcError as e:
4548
if self._user_address:
4649
print(f"Warning: {e.details()}")
47-
return False
48-
except Exception as e:
49-
return False
50+
return 0
5051

5152
def get_payment_metadata(self, service_client) -> list:
53+
if self.get_free_calls_available(service_client) <= 0:
54+
raise Exception(f"Free calls limit for address {self._user_address} has expired. Please use another payment strategy")
5255
signature, current_block_number = self.generate_signature(service_client)
5356
metadata = [("snet-free-call-auth-token-bin", self._free_call_token),
54-
("snet-free-call-token-expiry-block", str(self._token_expiry_date_block)),
5557
("snet-payment-type", "free-call"),
56-
("snet-free-call-user-id", self._user_address),
58+
("snet-free-call-user-address", self._user_address),
5759
("snet-current-block-number", str(current_block_number)),
5860
("snet-payment-channel-signature-bin", signature)]
5961

6062
return metadata
6163

62-
def select_channel(self, service_client) -> _channel.Channel:
63-
_, _, _, daemon_endpoint = service_client.get_service_details()
64-
endpoint_object = urlparse(daemon_endpoint)
65-
if endpoint_object.port is not None:
66-
channel_endpoint = endpoint_object.hostname + ":" + str(endpoint_object.port)
67-
else:
68-
channel_endpoint = endpoint_object.hostname
69-
70-
if endpoint_object.scheme == "http":
71-
channel = grpc.insecure_channel(channel_endpoint)
72-
elif endpoint_object.scheme == "https":
73-
channel = grpc.secure_channel(channel_endpoint, grpc.ssl_channel_credentials(root_certificates=certificate))
74-
else:
75-
raise ValueError('Unsupported scheme in service metadata ("{}")'.format(endpoint_object.scheme))
76-
return channel
77-
78-
def generate_signature(self, service_client) -> tuple[bytes, int]:
64+
def generate_signature(self, service_client, current_block_number=None, with_token=True) -> tuple[bytes, int]:
65+
if not current_block_number:
66+
current_block_number = service_client.get_current_block_number()
7967
org_id, service_id, group_id, _ = service_client.get_service_details()
8068

81-
if self._token_expiry_date_block == 0 or len(self._user_address) == 0 or len(self._free_call_token) == 0:
82-
raise Exception(
83-
"You are using default 'FreeCallPaymentStrategy' to use this strategy you need to pass "
84-
"'free_call_auth_token-bin','user_address','free-call-token-expiry-block' in config")
69+
message_types = ["string", "string", "string", "string", "string", "uint256", "bytes32"]
70+
message_values = ["__prefix_free_trial", self._user_address, org_id, service_id, group_id,
71+
current_block_number, self._free_call_token]
8572

86-
current_block_number = service_client.get_current_block_number()
73+
if not with_token:
74+
message_types = message_types[:-1]
75+
message_values = message_values[:-1]
8776

88-
message = web3.Web3.solidity_keccak(
89-
["string", "string", "string", "string", "string", "uint256", "bytes32"],
90-
["__prefix_free_trial", self._user_address, org_id, service_id, group_id, current_block_number,
91-
self._free_call_token]
92-
)
77+
message = web3.Web3.solidity_keccak(message_types, message_values)
9378
return service_client.generate_signature(message), current_block_number
9479

95-
def get_free_call_token_details(self, service_client) -> tuple[bytes, int]:
80+
def get_free_call_token_details(self, service_client, current_block_number=None) -> tuple[bytes, int]:
81+
82+
signature, current_block_number = self.generate_signature(service_client, current_block_number, with_token=False)
83+
9684
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
9785
state_service_pb2 = importlib.import_module("state_service_pb2")
9886

99-
request = state_service_pb2.GetFreeCallTokenRequest()
100-
request.address = self._user_address
87+
request = state_service_pb2.GetFreeCallTokenRequest(
88+
address=self._user_address,
89+
signature=signature,
90+
current_block=current_block_number
91+
)
10192

10293
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
10394
state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc")
10495

105-
channel = self.select_channel(service_client)
96+
channel = service_client.get_grpc_base_channel()
10697
stub = state_service_pb2_grpc.FreeCallStateServiceStub(channel)
10798
response = stub.GetFreeCallToken(request)
10899

snet/sdk/payment_strategies/prepaid_payment_strategy.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44

55
class PrePaidPaymentStrategy(PaymentStrategy):
66

7-
def __init__(self, concurrency_manager: ConcurrencyManager,
8-
block_offset: int = 240, call_allowance: int = 1):
9-
self.concurrency_manager = concurrency_manager
7+
def __init__(self, concurrent_calls: int=1, block_offset: int = 240, call_allowance: int = 1):
8+
self.concurrency_manager = ConcurrencyManager(concurrent_calls)
109
self.block_offset = block_offset
1110
self.call_allowance = call_allowance
1211

1312
def get_price(self, service_client):
1413
return service_client.get_price() * self.concurrency_manager.concurrent_calls
1514

16-
def get_payment_metadata(self, service_client, channel):
17-
if channel is None:
18-
channel = self.select_channel(service_client)
15+
def set_concurrent_calls(self, concurrent_calls):
16+
self.concurrency_manager.concurrent_calls = concurrent_calls
17+
18+
def get_payment_metadata(self, service_client):
19+
channel = self.select_channel(service_client)
1920
token = self.concurrency_manager.get_token(service_client, channel, self.get_price(service_client))
2021
metadata = [
2122
("snet-payment-type", "prepaid-call"),

0 commit comments

Comments
 (0)