diff --git a/fastapi_mail/email_utils/email_check.py b/fastapi_mail/email_utils/email_check.py index af535cf..8884c55 100644 --- a/fastapi_mail/email_utils/email_check.py +++ b/fastapi_mail/email_utils/email_check.py @@ -137,18 +137,16 @@ def __init__( or "https://gist.githubusercontent.com/Turall/3f32cb57270aed30d0c7f5e0800b2a92/raw/dcd9b47506e9da26d5772ccebf6913343e53cec9/temporary-email-address-domains" # noqa: E501 ) self.redis_enabled = False + self.redis_client = redis_client if db_provider == "redis": self.redis_enabled = True - if redis_client: - self.redis_client = redis_client - else: - self.username = username - self.redis_host = redis_host - self.redis_port = redis_port - self.redis_db = redis_db - self.redis_password = redis_password - self.options = options + self.username = username + self.redis_host = redis_host + self.redis_port = redis_port + self.redis_db = redis_db + self.redis_password = redis_password + self.options = options self.redis_error_msg = "redis is not connected" def catch_all_check(self): @@ -157,23 +155,25 @@ def catch_all_check(self): f"for class {self.__class__.__name__}" ) + def _get_redis_client(self) -> "aioredis.Redis": + if self.redis_client is None: + raise DBProvaiderError(self.redis_error_msg) + return self.redis_client + async def init_redis(self) -> bool: if not self.redis_enabled: raise DBProvaiderError(self.redis_error_msg) if self.redis_client is None: - # Create new Redis connection pool - if not self.username or not self.redis_password: - self.redis_client = await aioredis.from_url( - url="redis://localhost", encoding="UTF-8", **self.options - ) - else: - self.redis_client = await aioredis.from_url( - url=f"redis://{self.username}:{self.redis_password}@localhost:{self.redis_port}/{self.redis_db}", # noqa: E501 - encoding="UTF-8", - **self.options, - ) + url = f"redis://{self.redis_host}:{self.redis_port}/{self.redis_db}" + self.redis_client = await aioredis.from_url( + url=url, + encoding="UTF-8", + username=self.username, + password=self.redis_password, + **self.options, + ) else: - # Validate that the provided client is an async Redis client + # Validate that the provided client is an async Redis client. if not isinstance(self.redis_client, aioredis.Redis): raise DBProvaiderError( "Provided redis_client must be an async Redis client from redis.asyncio. " @@ -181,24 +181,25 @@ async def init_redis(self) -> bool: "Use: from redis.asyncio import Redis; client = Redis.from_url(...)" ) - temp_counter = await self.redis_client.get("temp_counter") - domain_counter = await self.redis_client.get("domain_counter") - blocked_emails = await self.redis_client.get("email_counter") + redis_client = self._get_redis_client() + temp_counter = await redis_client.get("temp_counter") + domain_counter = await redis_client.get("domain_counter") + blocked_emails = await redis_client.get("email_counter") if not temp_counter: - await self.redis_client.set("temp_counter", 0) + await redis_client.set("temp_counter", 0) if not domain_counter: - await self.redis_client.set("domain_counter", 0) + await redis_client.set("domain_counter", 0) if not blocked_emails: - await self.redis_client.set("email_counter", 0) + await redis_client.set("email_counter", 0) temp_domains = await self.fetch_temp_email_domains() - check_key = await self.redis_client.hgetall("temp_domains") + check_key = await redis_client.hgetall("temp_domains") if not check_key: kwargs = { - domain: await self.redis_client.incr("temp_counter") + domain: await redis_client.incr("temp_counter") for domain in temp_domains } - await self.redis_client.hset("temp_domains", mapping=kwargs) + await redis_client.hset("temp_domains", mapping=kwargs) return True @@ -225,18 +226,20 @@ async def fetch_temp_email_domains(self) -> Union[List[str], Any]: async def blacklist_add_domain(self, domain: str) -> None: """Add domain to blacklist""" if self.redis_enabled: - result = await self.redis_client.hget("blocked_domains", domain) + redis_client = self._get_redis_client() + result = await redis_client.hget("blocked_domains", domain) if not result: - incr = await self.redis_client.incr("domain_counter") - await self.redis_client.hset("blocked_domains", domain, incr) + incr = await redis_client.incr("domain_counter") + await redis_client.hset("blocked_domains", domain, incr) else: self.BLOCKED_DOMAINS.add(domain) async def blacklist_rm_domain(self, domain: str) -> None: if self.redis_enabled: - res = await self.redis_client.hdel("blocked_domains", domain) + redis_client = self._get_redis_client() + res = await redis_client.hdel("blocked_domains", domain) if res: - await self.redis_client.decr("domain_counter") + await redis_client.decr("domain_counter") else: self.BLOCKED_DOMAINS.remove(domain) @@ -244,37 +247,41 @@ async def blacklist_add_email(self, email: str) -> None: """Add email address to blacklist""" if self.validate_email(email): if self.redis_enabled: - blocked_domain = await self.redis_client.hget("blocked_emails", email) + redis_client = self._get_redis_client() + blocked_domain = await redis_client.hget("blocked_emails", email) if not blocked_domain: - inc = await self.redis_client.incr("email_counter") - await self.redis_client.hset("blocked_emails", email, inc) + inc = await redis_client.incr("email_counter") + await redis_client.hset("blocked_emails", email, inc) else: self.BLOCKED_ADDRESSES.add(email) async def blacklist_rm_email(self, email: str) -> None: if self.redis_enabled: - res = await self.redis_client.hdel("blocked_emails", email) + redis_client = self._get_redis_client() + res = await redis_client.hdel("blocked_emails", email) if res: - await self.redis_client.decr("email_counter") + await redis_client.decr("email_counter") else: self.BLOCKED_ADDRESSES.remove(email) async def add_temp_domain(self, domain_lists: List[str]) -> None: """Manually add temporary email""" if self.redis_enabled: + redis_client = self._get_redis_client() for domain in domain_lists: - temp_email = await self.redis_client.hget("temp_domains", domain) + temp_email = await redis_client.hget("temp_domains", domain) if not temp_email: - incr = await self.redis_client.incr("temp_counter") - await self.redis_client.hset("temp_domains", domain, incr) + incr = await redis_client.incr("temp_counter") + await redis_client.hset("temp_domains", domain, incr) else: self.TEMP_EMAIL_DOMAINS.extend(domain_lists) async def blacklist_rm_temp(self, domain: str) -> bool: if self.redis_enabled: - res = await self.redis_client.hdel("temp_domains", domain) + redis_client = self._get_redis_client() + res = await redis_client.hdel("temp_domains", domain) if res: - await self.redis_client.decr("temp_counter") + await redis_client.decr("temp_counter") else: self.TEMP_EMAIL_DOMAINS.remove(domain) return True @@ -285,7 +292,8 @@ async def is_disposable(self, email: str) -> bool: _, domain = email.split("@") result = None if self.redis_enabled: - result = await self.redis_client.hget("temp_domains", domain) + redis_client = self._get_redis_client() + result = await redis_client.hget("temp_domains", domain) return bool(result) return domain in self.TEMP_EMAIL_DOMAINS return False @@ -295,7 +303,8 @@ async def is_blocked_domain(self, domain: str) -> bool: if not self.redis_enabled: return domain in self.BLOCKED_DOMAINS - blocked_email = await self.redis_client.hget("blocked_domains", domain) + redis_client = self._get_redis_client() + blocked_email = await redis_client.hget("blocked_domains", domain) return bool(blocked_email) async def is_blocked_address(self, email: str) -> bool: @@ -304,7 +313,8 @@ async def is_blocked_address(self, email: str) -> bool: if not self.redis_enabled: return email in self.BLOCKED_ADDRESSES - blocked_domain = await self.redis_client.hget("blocked_emails", email) + redis_client = self._get_redis_client() + blocked_domain = await redis_client.hget("blocked_emails", email) return bool(blocked_domain) return False @@ -330,7 +340,8 @@ async def check_mx_record( async def blocked_email_count(self) -> int: """count all blocked emails in redis""" if self.redis_enabled: - result = await self.redis_client.get("email_counter") + redis_client = self._get_redis_client() + result = await redis_client.get("email_counter") if result is not None: return result return len(self.BLOCKED_ADDRESSES) @@ -338,7 +349,8 @@ async def blocked_email_count(self) -> int: async def blocked_domain_count(self) -> int: """count all blocked domains in redis""" if self.redis_enabled: - result = await self.redis_client.get("domain_counter") + redis_client = self._get_redis_client() + result = await redis_client.get("domain_counter") if result is not None: return result return len(self.BLOCKED_DOMAINS) @@ -346,7 +358,8 @@ async def blocked_domain_count(self) -> int: async def temp_email_count(self) -> int: """count all temporary emails in redis""" if self.redis_enabled: - result = await self.redis_client.get("temp_counter") + redis_client = self._get_redis_client() + result = await redis_client.get("temp_counter") if result is not None: return result return len(self.TEMP_EMAIL_DOMAINS) @@ -354,7 +367,8 @@ async def temp_email_count(self) -> int: async def close_connections(self) -> bool: """for correctly close connection from redis""" if self.redis_enabled: - await self.redis_client.close() + redis_client = self._get_redis_client() + await redis_client.aclose() # type: ignore[attr-defined] return True raise DBProvaiderError(self.redis_error_msg) diff --git a/fastapi_mail/msg.py b/fastapi_mail/msg.py index 27f2395..7afc85f 100644 --- a/fastapi_mail/msg.py +++ b/fastapi_mail/msg.py @@ -68,10 +68,10 @@ async def attach_file(self, message: MIMEMultipart, attachment: Any): part = MIMEBase( _maintype=file_meta["mime_type"], _subtype=file_meta["mime_subtype"] ) - + # If the file-like object has a content-type header, # use that to determine the MIME type of the attachment - elif hasattr(file, 'headers') and file.headers.get("content-type"): + elif hasattr(file, "headers") and file.headers.get("content-type"): content_type = file.headers.get("content-type") if "/" in content_type: _maintype, _subtype = content_type.split("/", 1) diff --git a/tests/test_connection.py b/tests/test_connection.py index 272a23a..8876b0a 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -90,10 +90,7 @@ async def test_attachement_message(mail_config): assert len(outbox) == 1 assert mail._payload[1].get_content_maintype() == "text" - assert ( - mail._payload[1].__dict__.get("_headers")[0][1] - == "text/plain" - ) + assert mail._payload[1].__dict__.get("_headers")[0][1] == "text/plain" @pytest.mark.asyncio @@ -376,10 +373,7 @@ async def test_send_msg_with_alternative_body_and_attachements(mail_config): assert mail._payload[1].get_content_maintype() == "text" - assert ( - mail._payload[1].__dict__.get("_headers")[0][1] - == "text/plain" - ) + assert mail._payload[1].__dict__.get("_headers")[0][1] == "text/plain" @pytest.mark.asyncio diff --git a/tests/test_redis_config.py b/tests/test_redis_config.py index 9a70d09..3c9f108 100644 --- a/tests/test_redis_config.py +++ b/tests/test_redis_config.py @@ -1,5 +1,7 @@ import pytest +from fastapi_mail.email_utils import DefaultChecker + @pytest.mark.asyncio async def test_redis_checker(redis_checker): @@ -31,3 +33,68 @@ async def test_redis_checker(redis_checker): assert await redis_checker.is_blocked_address(email) is True assert await redis_checker.check_mx_record(domain) is True + + +class FakeRedisClient: + async def get(self, key): + return 1 + + async def hgetall(self, key): + return {"existing": 1} + + +async def temp_domains(): + return ["example.com"] + + +@pytest.mark.asyncio +async def test_default_redis_connection_uses_localhost_fallback(monkeypatch): + urls = [] + + async def from_url(url, **kwargs): + urls.append(url) + return FakeRedisClient() + + checker = DefaultChecker(db_provider="redis") + monkeypatch.setattr( + "fastapi_mail.email_utils.email_check.aioredis.from_url", from_url + ) + monkeypatch.setattr(checker, "fetch_temp_email_domains", temp_domains) + + await checker.init_redis() + + assert urls == ["redis://localhost:6379/0"] + + +@pytest.mark.asyncio +async def test_custom_redis_host_overrides_localhost_fallback(monkeypatch): + urls = [] + + async def from_url(url, **kwargs): + urls.append(url) + return FakeRedisClient() + + checker = DefaultChecker(db_provider="redis", redis_host="redis") + monkeypatch.setattr( + "fastapi_mail.email_utils.email_check.aioredis.from_url", from_url + ) + monkeypatch.setattr(checker, "fetch_temp_email_domains", temp_domains) + + await checker.init_redis() + + assert urls == ["redis://redis:6379/0"] + + +@pytest.mark.asyncio +async def test_existing_redis_client_does_not_create_new_connection( + monkeypatch, redis_checker +): + async def from_url(url, **kwargs): + raise AssertionError("from_url should not be called for an existing client") + + monkeypatch.setattr( + "fastapi_mail.email_utils.email_check.aioredis.from_url", from_url + ) + monkeypatch.setattr(redis_checker, "fetch_temp_email_domains", temp_domains) + + await redis_checker.init_redis()