Skip to content

Commit 9e69738

Browse files
authored
Merge pull request #70 from singnet/development
Cache channels
2 parents 45b9ea3 + 5470c75 commit 9e69738

5 files changed

Lines changed: 105 additions & 47 deletions

File tree

snet/sdk/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import google.protobuf.internal.api_implementation
99

10+
from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider
11+
1012
with warnings.catch_warnings():
1113
# Suppress the eth-typing package`s warnings related to some new networks
1214
warnings.filterwarnings("ignore", "Network .* does not have a valid ChainId. eth-typing should be "
@@ -76,6 +78,7 @@ def __init__(self, sdk_config: Config, metadata_provider=None):
7678
self.registry_contract = get_contract_object(self.web3, "Registry", _registry_contract_address)
7779

7880
self.account = Account(self.web3, sdk_config, self.mpe_contract)
81+
self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract)
7982

8083
def create_service_client(self, org_id: str, service_id: str, group_name=None,
8184
payment_channel_management_strategy=None,
@@ -122,7 +125,7 @@ def create_service_client(self, org_id: str, service_id: str, group_name=None,
122125
pb2_module = self.get_module_by_keyword(org_id, service_id, keyword="pb2.py")
123126

124127
service_client = ServiceClient(org_id, service_id, service_metadata, group, service_stub, strategy,
125-
options, self.mpe_contract, self.account, self.web3, pb2_module)
128+
options, self.mpe_contract, self.account, self.web3, pb2_module, self.payment_channel_provider)
126129
return service_client
127130

128131
def get_service_stub(self, org_id: str, service_id: str) -> ServiceStub:

snet/sdk/mpe/mpe_contract.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@ def __init__(self, w3, address=None):
88
self.contract = get_contract_object(self.web3, "MultiPartyEscrow")
99
else:
1010
self.contract = get_contract_object(self.web3, "MultiPartyEscrow", address)
11-
self.event_topics = [self.web3.keccak(
12-
text="ChannelOpen(uint256,uint256,address,address,address,bytes32,uint256,uint256)").hex()]
13-
self.deployment_block = get_contract_deployment_block(self.web3, "MultiPartyEscrow")
1411

1512
def balance(self, address):
1613
return self.contract.functions.balances(address).call()
Lines changed: 92 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,124 @@
1+
from pathlib import Path
2+
13
from web3._utils.events import get_event_data
24
from eth_abi.codec import ABICodec
5+
import pickle
36

47
from snet.sdk.mpe.payment_channel import PaymentChannel
58
from snet.contracts import get_contract_deployment_block
69

710

811
BLOCKS_PER_BATCH = 5000
12+
CHANNELS_DIR = Path.home().joinpath(".snet", "cache", "mpe")
913

1014

1115
class PaymentChannelProvider(object):
12-
def __init__(self, w3, payment_channel_state_service_client, mpe_contract):
16+
def __init__(self, w3, mpe_contract):
1317
self.web3 = w3
1418

1519
self.mpe_contract = mpe_contract
1620
self.event_topics = [self.web3.keccak(
1721
text="ChannelOpen(uint256,uint256,address,address,address,bytes32,uint256,uint256)").hex()]
1822
self.deployment_block = get_contract_deployment_block(self.web3, "MultiPartyEscrow")
19-
self.payment_channel_state_service_client = payment_channel_state_service_client
20-
21-
def get_past_open_channels(self, account, payment_address, group_id, starting_block_number=0, to_block_number=None):
22-
if to_block_number is None:
23-
to_block_number = self.web3.eth.block_number
24-
25-
if starting_block_number == 0:
26-
starting_block_number = self.deployment_block
27-
23+
self.mpe_address = mpe_contract.contract.address
24+
self.channels_file = CHANNELS_DIR.joinpath(str(self.mpe_address), "channels.pickle")
25+
26+
def update_cache(self):
27+
channels = []
28+
last_read_block = self.deployment_block
29+
30+
if not self.channels_file.exists():
31+
print(f"Channels cache is empty. Caching may take some time when first accessing channels.\nCaching in progress...")
32+
self.channels_file.parent.mkdir(parents=True, exist_ok=True)
33+
with open(self.channels_file, "wb") as f:
34+
empty_dict = {
35+
"last_read_block": last_read_block,
36+
"channels": channels
37+
}
38+
pickle.dump(empty_dict, f)
39+
else:
40+
with open(self.channels_file, "rb") as f:
41+
load_dict = pickle.load(f)
42+
last_read_block = load_dict["last_read_block"]
43+
channels = load_dict["channels"]
44+
45+
current_block_number = self.web3.eth.block_number
46+
47+
if last_read_block < current_block_number:
48+
new_channels = self._get_all_channels_from_blockchain_logs_to_dicts(last_read_block, current_block_number)
49+
channels = channels + new_channels
50+
last_read_block = current_block_number
51+
52+
with open(self.channels_file, "wb") as f:
53+
dict_to_save = {
54+
"last_read_block": last_read_block,
55+
"channels": channels
56+
}
57+
pickle.dump(dict_to_save, f)
58+
59+
def _event_data_args_to_dict(self, event_data):
60+
return {
61+
"channel_id": event_data["channelId"],
62+
"sender": event_data["sender"],
63+
"signer": event_data["signer"],
64+
"recipient": event_data["recipient"],
65+
"group_id": event_data["groupId"],
66+
}
67+
68+
def _get_all_channels_from_blockchain_logs_to_dicts(self, starting_block_number, to_block_number):
2869
codec: ABICodec = self.web3.codec
2970

3071
logs = []
3172
from_block = starting_block_number
3273
while from_block <= to_block_number:
3374
to_block = min(from_block + BLOCKS_PER_BATCH, to_block_number)
34-
logs = logs + self.web3.eth.get_logs({"fromBlock": from_block, "toBlock": to_block,
35-
"address": self.mpe_contract.contract.address,
36-
"topics": self.event_topics})
75+
logs = logs + self.web3.eth.get_logs({"fromBlock": from_block,
76+
"toBlock": to_block,
77+
"address": self.mpe_address,
78+
"topics": self.event_topics})
3779
from_block = to_block + 1
3880

3981
event_abi = self.mpe_contract.contract._find_matching_event_abi(event_name="ChannelOpen")
40-
channels_opened = list(filter(
41-
lambda
42-
channel: (channel.sender == account.address or channel.signer == account.signer_address) and channel.recipient ==
43-
payment_address and channel.groupId == group_id,
44-
45-
[get_event_data(codec, event_abi, l)["args"] for l in logs]
46-
))
47-
return list(map(lambda channel: PaymentChannel(channel["channelId"], self.web3, account,
48-
self.payment_channel_state_service_client, self.mpe_contract),
49-
channels_opened))
5082

51-
def open_channel(self, account, amount, expiration, payment_address, group_id):
83+
event_data_list = [get_event_data(codec, event_abi, l)["args"] for l in logs]
84+
channels_opened = list(map(self._event_data_args_to_dict, event_data_list))
85+
86+
return channels_opened
87+
88+
def _get_channels_from_cache(self):
89+
self.update_cache()
90+
with open(self.channels_file, "rb") as f:
91+
load_dict = pickle.load(f)
92+
return load_dict["channels"]
5293

94+
def get_past_open_channels(self, account, payment_address, group_id, payment_channel_state_service_client):
95+
96+
dict_channels = self._get_channels_from_cache()
97+
98+
channels_opened = list(filter(lambda channel: (channel["sender"] == account.address
99+
or channel["signer"] == account.signer_address)
100+
and channel["recipient"] == payment_address
101+
and channel["group_id"] == group_id,
102+
dict_channels))
103+
104+
return list(map(lambda channel: PaymentChannel(channel["channel_id"],
105+
self.web3,
106+
account,
107+
payment_channel_state_service_client,
108+
self.mpe_contract),
109+
channels_opened))
110+
111+
def open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client):
53112
receipt = self.mpe_contract.open_channel(account, payment_address, group_id, amount, expiration)
54-
return self._get_newly_opened_channel(receipt, account, payment_address, group_id)
113+
return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client)
55114

56-
def deposit_and_open_channel(self, account, amount, expiration, payment_address, group_id):
57-
receipt = self.mpe_contract.deposit_and_open_channel(account, payment_address, group_id, amount,
58-
expiration)
59-
return self._get_newly_opened_channel(receipt, account, payment_address, group_id)
115+
def deposit_and_open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client):
116+
receipt = self.mpe_contract.deposit_and_open_channel(account, payment_address, group_id, amount, expiration)
117+
return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client)
60118

61-
def _get_newly_opened_channel(self, receipt,account, payment_address, group_id):
62-
open_channels = self.get_past_open_channels(account, payment_address, group_id, receipt["blockNumber"],
63-
receipt["blockNumber"])
64-
if len(open_channels) == 0:
119+
def _get_newly_opened_channel(self, account, payment_address, group_id, receipt, payment_channel_state_service_client):
120+
open_channels = self.get_past_open_channels(account, payment_address, group_id, payment_channel_state_service_client)
121+
if not open_channels:
65122
raise Exception(f"Error while opening channel, please check transaction {receipt.transactionHash.hex()} ")
66-
return open_channels[0]
123+
return open_channels[-1]
124+

snet/sdk/service_client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path, find_file_by_keyword
1414

1515
import snet.sdk.generic_client_interceptor as generic_client_interceptor
16-
from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider
1716

1817

1918
class _ClientCallDetails(
@@ -26,7 +25,7 @@ class _ClientCallDetails(
2625

2726
class ServiceClient:
2827
def __init__(self, org_id, service_id, service_metadata, group, service_stub, payment_strategy,
29-
options, mpe_contract, account, sdk_web3, pb2_module):
28+
options, mpe_contract, account, sdk_web3, pb2_module, payment_channel_provider):
3029
self.org_id = org_id
3130
self.service_id = service_id
3231
self.options = options
@@ -38,9 +37,8 @@ def __init__(self, org_id, service_id, service_metadata, group, service_stub, pa
3837
self.__base_grpc_channel = self._get_grpc_channel()
3938
self.grpc_channel = grpc.intercept_channel(self.__base_grpc_channel,
4039
generic_client_interceptor.create(self._intercept_call))
41-
self.payment_channel_provider = PaymentChannelProvider(sdk_web3,
42-
self._generate_payment_channel_state_service_client(),
43-
mpe_contract)
40+
self.payment_channel_provider = payment_channel_provider
41+
self.payment_channel_state_service_client = self._generate_payment_channel_state_service_client()
4442
self.service = self._generate_grpc_stub(service_stub)
4543
self.pb2_module = importlib.import_module(pb2_module) if isinstance(pb2_module, str) else pb2_module
4644
self.payment_channels = []
@@ -122,7 +120,8 @@ def load_open_channels(self):
122120
payment_address = self.group["payment"]["payment_address"]
123121
group_id = base64.b64decode(str(self.group["group_id"]))
124122
new_payment_channels = self.payment_channel_provider.get_past_open_channels(self.account, payment_address,
125-
group_id, self.last_read_block)
123+
group_id,
124+
self.payment_channel_state_service_client)
126125
self.payment_channels = self.payment_channels + \
127126
self._filter_existing_channels_from_new_payment_channels(new_payment_channels)
128127
self.last_read_block = current_block_number
@@ -150,13 +149,14 @@ def open_channel(self, amount, expiration):
150149
payment_address = self.group["payment"]["payment_address"]
151150
group_id = base64.b64decode(str(self.group["group_id"]))
152151
return self.payment_channel_provider.open_channel(self.account, amount, expiration, payment_address,
153-
group_id)
152+
group_id, self.payment_channel_state_service_client)
154153

155154
def deposit_and_open_channel(self, amount, expiration):
156155
payment_address = self.group["payment"]["payment_address"]
157156
group_id = base64.b64decode(str(self.group["group_id"]))
158157
return self.payment_channel_provider.deposit_and_open_channel(self.account, amount, expiration,
159-
payment_address, group_id)
158+
payment_address, group_id,
159+
self.payment_channel_state_service_client)
160160

161161
def get_price(self):
162162
return self.group["pricing"][0]["price_in_cogs"]

version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.5.0"
1+
__version__ = "3.5.1"

0 commit comments

Comments
 (0)