77import asyncio
88import time
99import datetime
10+ try :
11+ from urllib import urlparse , unquote_plus , urlencode , quote_plus
12+ except ImportError :
13+ from urllib .parse import urlparse , unquote_plus , urlencode , quote_plus
1014
1115from uamqp import authentication , constants , types , errors
1216from uamqp import (
1317 Message ,
14- Source ,
1518 ConnectionAsync ,
1619 AMQPClientAsync ,
1720 SendClientAsync ,
@@ -37,7 +40,7 @@ class EventHubClientAsync(EventHubClient):
3740 sending events to and receiving events from the Azure Event Hubs service.
3841 """
3942
40- def _create_auth (self , auth_uri , username , password ): # pylint: disable=no-self-use
43+ def _create_auth (self , username = None , password = None ): # pylint: disable=no-self-use
4144 """
4245 Create an ~uamqp.authentication.cbs_auth_async.SASTokenAuthAsync instance to authenticate
4346 the session.
@@ -49,32 +52,13 @@ def _create_auth(self, auth_uri, username, password): # pylint: disable=no-self
4952 :param password: The shared access key.
5053 :type password: str
5154 """
55+ username = username or self ._auth_config ['username' ]
56+ password = password or self ._auth_config ['password' ]
5257 if "@sas.root" in username :
53- return authentication .SASLPlain (self .address .hostname , username , password )
54- return authentication .SASTokenAsync .from_shared_access_key (auth_uri , username , password )
55-
56- def _create_connection_async (self ):
57- """
58- Create a new ~uamqp._async.connection_async.ConnectionAsync instance that will be shared between all
59- AsyncSender/AsyncReceiver clients.
60- """
61- if not self .connection :
62- log .info ("{}: Creating connection with address={}" .format (
63- self .container_id , self .address .geturl ()))
64- self .connection = ConnectionAsync (
65- self .address .hostname ,
66- self .auth ,
67- container_id = self .container_id ,
68- properties = self ._create_properties (),
69- debug = self .debug )
70-
71- async def _close_connection_async (self ):
72- """
73- Close and destroy the connection async.
74- """
75- if self .connection :
76- await self .connection .destroy_async ()
77- self .connection = None
58+ return authentication .SASLPlain (
59+ self .address .hostname , username , password , http_proxy = self .http_proxy )
60+ return authentication .SASTokenAsync .from_shared_access_key (
61+ self .auth_uri , username , password , timeout = 60 , http_proxy = self .http_proxy )
7862
7963 async def _close_clients_async (self ):
8064 """
@@ -85,17 +69,13 @@ async def _close_clients_async(self):
8569 async def _wait_for_client (self , client ):
8670 try :
8771 while client .get_handler_state ().value == 2 :
88- await self . connection . work_async ()
72+ await client . _handler . _connection . work_async () # pylint: disable=protected-access
8973 except Exception as exp : # pylint: disable=broad-except
9074 await client .close_async (exception = exp )
9175
9276 async def _start_client_async (self , client ):
9377 try :
94- await client .open_async (self .connection )
95- started = await client .has_started ()
96- while not started :
97- await self .connection .work_async ()
98- started = await client .has_started ()
78+ await client .open_async ()
9979 except Exception as exp : # pylint: disable=broad-except
10080 await client .close_async (exception = exp )
10181
@@ -108,9 +88,8 @@ async def _handle_redirect(self, redirects):
10888 redirects = [c .redirected for c in self .clients if c .redirected ]
10989 if not all (r .hostname == redirects [0 ].hostname for r in redirects ):
11090 raise EventHubError ("Multiple clients attempting to redirect to different hosts." )
111- self .auth = self ._create_auth (redirects [0 ].address .decode ('utf-8' ), ** self ._auth_config )
112- await self .connection .redirect_async (redirects [0 ], self .auth )
113- await asyncio .gather (* [c .open_async (self .connection ) for c in self .clients ])
91+ self ._process_redirect_uri (redirects [0 ])
92+ await asyncio .gather (* [c .open_async () for c in self .clients ])
11493
11594 async def run_async (self ):
11695 """
@@ -125,7 +104,6 @@ async def run_async(self):
125104 :rtype: list[~azure.eventhub.common.EventHubError]
126105 """
127106 log .info ("{}: Starting {} clients" .format (self .container_id , len (self .clients )))
128- self ._create_connection_async ()
129107 tasks = [self ._start_client_async (c ) for c in self .clients ]
130108 try :
131109 await asyncio .gather (* tasks )
@@ -153,18 +131,21 @@ async def stop_async(self):
153131 log .info ("{}: Stopping {} clients" .format (self .container_id , len (self .clients )))
154132 self .stopped = True
155133 await self ._close_clients_async ()
156- await self ._close_connection_async ()
157134
158135 async def get_eventhub_info_async (self ):
159136 """
160137 Get details on the specified EventHub async.
161138
162139 :rtype: dict
163140 """
164- eh_name = self .address .path .lstrip ('/' )
165- target = "amqps://{}/{}" .format (self .address .hostname , eh_name )
166- async with AMQPClientAsync (target , auth = self .auth , debug = self .debug ) as mgmt_client :
167- mgmt_msg = Message (application_properties = {'name' : eh_name })
141+ alt_creds = {
142+ "username" : self ._auth_config .get ("iot_username" ),
143+ "password" :self ._auth_config .get ("iot_password" )}
144+ try :
145+ mgmt_auth = self ._create_auth (** alt_creds )
146+ mgmt_client = AMQPClientAsync (self .mgmt_target , auth = mgmt_auth , debug = self .debug )
147+ await mgmt_client .open_async ()
148+ mgmt_msg = Message (application_properties = {'name' : self .eh_name })
168149 response = await mgmt_client .mgmt_request_async (
169150 mgmt_msg ,
170151 constants .READ_OPERATION ,
@@ -180,6 +161,8 @@ async def get_eventhub_info_async(self):
180161 output ['partition_count' ] = eh_info [b'partition_count' ]
181162 output ['partition_ids' ] = [p .decode ('utf-8' ) for p in eh_info [b'partition_ids' ]]
182163 return output
164+ finally :
165+ await mgmt_client .close_async ()
183166
184167 def add_async_receiver (self , consumer_group , partition , offset = None , prefetch = 300 , operation = None , loop = None ):
185168 """
@@ -201,10 +184,7 @@ def add_async_receiver(self, consumer_group, partition, offset=None, prefetch=30
201184 path = self .address .path + operation if operation else self .address .path
202185 source_url = "amqps://{}{}/ConsumerGroups/{}/Partitions/{}" .format (
203186 self .address .hostname , path , consumer_group , partition )
204- source = Source (source_url )
205- if offset is not None :
206- source .set_filter (offset .selector ())
207- handler = AsyncReceiver (self , source , prefetch = prefetch , loop = loop )
187+ handler = AsyncReceiver (self , source_url , offset = offset , prefetch = prefetch , loop = loop )
208188 self .clients .append (handler )
209189 return handler
210190
0 commit comments