|
| 1 | +from pathlib import Path |
| 2 | + |
1 | 3 | from web3._utils.events import get_event_data |
2 | 4 | from eth_abi.codec import ABICodec |
| 5 | +import pickle |
3 | 6 |
|
4 | 7 | from snet.sdk.mpe.payment_channel import PaymentChannel |
5 | 8 | from snet.contracts import get_contract_deployment_block |
6 | 9 |
|
7 | 10 |
|
8 | 11 | BLOCKS_PER_BATCH = 5000 |
| 12 | +CHANNELS_DIR = Path.home().joinpath(".snet", "cache", "mpe") |
9 | 13 |
|
10 | 14 |
|
11 | 15 | class PaymentChannelProvider(object): |
12 | | - def __init__(self, w3, payment_channel_state_service_client, mpe_contract): |
| 16 | + def __init__(self, w3, mpe_contract): |
13 | 17 | self.web3 = w3 |
14 | 18 |
|
15 | 19 | self.mpe_contract = mpe_contract |
16 | 20 | self.event_topics = [self.web3.keccak( |
17 | 21 | text="ChannelOpen(uint256,uint256,address,address,address,bytes32,uint256,uint256)").hex()] |
18 | 22 | self.deployment_block = get_contract_deployment_block(self.web3, "MultiPartyEscrow") |
19 | | - self.payment_channel_state_service_client = payment_channel_state_service_client |
20 | | - |
21 | | - def get_past_open_channels(self, account, payment_address, group_id, starting_block_number=0, to_block_number=None): |
22 | | - if to_block_number is None: |
23 | | - to_block_number = self.web3.eth.block_number |
24 | | - |
25 | | - if starting_block_number == 0: |
26 | | - starting_block_number = self.deployment_block |
27 | | - |
| 23 | + self.mpe_address = mpe_contract.contract.address |
| 24 | + self.channels_file = CHANNELS_DIR.joinpath(str(self.mpe_address), "channels.pickle") |
| 25 | + |
| 26 | + def update_cache(self): |
| 27 | + channels = [] |
| 28 | + last_read_block = self.deployment_block |
| 29 | + |
| 30 | + if not self.channels_file.exists(): |
| 31 | + print(f"Channels cache is empty. Caching may take some time when first accessing channels.\nCaching in progress...") |
| 32 | + self.channels_file.parent.mkdir(parents=True, exist_ok=True) |
| 33 | + with open(self.channels_file, "wb") as f: |
| 34 | + empty_dict = { |
| 35 | + "last_read_block": last_read_block, |
| 36 | + "channels": channels |
| 37 | + } |
| 38 | + pickle.dump(empty_dict, f) |
| 39 | + else: |
| 40 | + with open(self.channels_file, "rb") as f: |
| 41 | + load_dict = pickle.load(f) |
| 42 | + last_read_block = load_dict["last_read_block"] |
| 43 | + channels = load_dict["channels"] |
| 44 | + |
| 45 | + current_block_number = self.web3.eth.block_number |
| 46 | + |
| 47 | + if last_read_block < current_block_number: |
| 48 | + new_channels = self._get_all_channels_from_blockchain_logs_to_dicts(last_read_block, current_block_number) |
| 49 | + channels = channels + new_channels |
| 50 | + last_read_block = current_block_number |
| 51 | + |
| 52 | + with open(self.channels_file, "wb") as f: |
| 53 | + dict_to_save = { |
| 54 | + "last_read_block": last_read_block, |
| 55 | + "channels": channels |
| 56 | + } |
| 57 | + pickle.dump(dict_to_save, f) |
| 58 | + |
| 59 | + def _event_data_args_to_dict(self, event_data): |
| 60 | + return { |
| 61 | + "channel_id": event_data["channelId"], |
| 62 | + "sender": event_data["sender"], |
| 63 | + "signer": event_data["signer"], |
| 64 | + "recipient": event_data["recipient"], |
| 65 | + "group_id": event_data["groupId"], |
| 66 | + } |
| 67 | + |
| 68 | + def _get_all_channels_from_blockchain_logs_to_dicts(self, starting_block_number, to_block_number): |
28 | 69 | codec: ABICodec = self.web3.codec |
29 | 70 |
|
30 | 71 | logs = [] |
31 | 72 | from_block = starting_block_number |
32 | 73 | while from_block <= to_block_number: |
33 | 74 | to_block = min(from_block + BLOCKS_PER_BATCH, to_block_number) |
34 | | - logs = logs + self.web3.eth.get_logs({"fromBlock": from_block, "toBlock": to_block, |
35 | | - "address": self.mpe_contract.contract.address, |
36 | | - "topics": self.event_topics}) |
| 75 | + logs = logs + self.web3.eth.get_logs({"fromBlock": from_block, |
| 76 | + "toBlock": to_block, |
| 77 | + "address": self.mpe_address, |
| 78 | + "topics": self.event_topics}) |
37 | 79 | from_block = to_block + 1 |
38 | 80 |
|
39 | 81 | event_abi = self.mpe_contract.contract._find_matching_event_abi(event_name="ChannelOpen") |
40 | | - channels_opened = list(filter( |
41 | | - lambda |
42 | | - channel: (channel.sender == account.address or channel.signer == account.signer_address) and channel.recipient == |
43 | | - payment_address and channel.groupId == group_id, |
44 | | - |
45 | | - [get_event_data(codec, event_abi, l)["args"] for l in logs] |
46 | | - )) |
47 | | - return list(map(lambda channel: PaymentChannel(channel["channelId"], self.web3, account, |
48 | | - self.payment_channel_state_service_client, self.mpe_contract), |
49 | | - channels_opened)) |
50 | 82 |
|
51 | | - def open_channel(self, account, amount, expiration, payment_address, group_id): |
| 83 | + event_data_list = [get_event_data(codec, event_abi, l)["args"] for l in logs] |
| 84 | + channels_opened = list(map(self._event_data_args_to_dict, event_data_list)) |
| 85 | + |
| 86 | + return channels_opened |
| 87 | + |
| 88 | + def _get_channels_from_cache(self): |
| 89 | + self.update_cache() |
| 90 | + with open(self.channels_file, "rb") as f: |
| 91 | + load_dict = pickle.load(f) |
| 92 | + return load_dict["channels"] |
52 | 93 |
|
| 94 | + def get_past_open_channels(self, account, payment_address, group_id, payment_channel_state_service_client): |
| 95 | + |
| 96 | + dict_channels = self._get_channels_from_cache() |
| 97 | + |
| 98 | + channels_opened = list(filter(lambda channel: (channel["sender"] == account.address |
| 99 | + or channel["signer"] == account.signer_address) |
| 100 | + and channel["recipient"] == payment_address |
| 101 | + and channel["group_id"] == group_id, |
| 102 | + dict_channels)) |
| 103 | + |
| 104 | + return list(map(lambda channel: PaymentChannel(channel["channel_id"], |
| 105 | + self.web3, |
| 106 | + account, |
| 107 | + payment_channel_state_service_client, |
| 108 | + self.mpe_contract), |
| 109 | + channels_opened)) |
| 110 | + |
| 111 | + def open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client): |
53 | 112 | receipt = self.mpe_contract.open_channel(account, payment_address, group_id, amount, expiration) |
54 | | - return self._get_newly_opened_channel(receipt, account, payment_address, group_id) |
| 113 | + return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client) |
55 | 114 |
|
56 | | - def deposit_and_open_channel(self, account, amount, expiration, payment_address, group_id): |
57 | | - receipt = self.mpe_contract.deposit_and_open_channel(account, payment_address, group_id, amount, |
58 | | - expiration) |
59 | | - return self._get_newly_opened_channel(receipt, account, payment_address, group_id) |
| 115 | + def deposit_and_open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client): |
| 116 | + receipt = self.mpe_contract.deposit_and_open_channel(account, payment_address, group_id, amount, expiration) |
| 117 | + return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client) |
60 | 118 |
|
61 | | - def _get_newly_opened_channel(self, receipt,account, payment_address, group_id): |
62 | | - open_channels = self.get_past_open_channels(account, payment_address, group_id, receipt["blockNumber"], |
63 | | - receipt["blockNumber"]) |
64 | | - if len(open_channels) == 0: |
| 119 | + def _get_newly_opened_channel(self, account, payment_address, group_id, receipt, payment_channel_state_service_client): |
| 120 | + open_channels = self.get_past_open_channels(account, payment_address, group_id, payment_channel_state_service_client) |
| 121 | + if not open_channels: |
65 | 122 | raise Exception(f"Error while opening channel, please check transaction {receipt.transactionHash.hex()} ") |
66 | | - return open_channels[0] |
| 123 | + return open_channels[-1] |
| 124 | + |
0 commit comments