Skip to content
118 changes: 66 additions & 52 deletions fastapi_mail/email_utils/email_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,16 @@ def __init__(
or "https://gist.githubusercontent.com/Turall/3f32cb57270aed30d0c7f5e0800b2a92/raw/dcd9b47506e9da26d5772ccebf6913343e53cec9/temporary-email-address-domains" # noqa: E501
)
self.redis_enabled = False
self.redis_client = redis_client

if db_provider == "redis":
self.redis_enabled = True
if redis_client:
self.redis_client = redis_client
else:
self.username = username
self.redis_host = redis_host
self.redis_port = redis_port
self.redis_db = redis_db
self.redis_password = redis_password
self.options = options
self.username = username
self.redis_host = redis_host
self.redis_port = redis_port
self.redis_db = redis_db
self.redis_password = redis_password
self.options = options
self.redis_error_msg = "redis is not connected"

def catch_all_check(self):
Expand All @@ -157,48 +155,51 @@ def catch_all_check(self):
f"for class {self.__class__.__name__}"
)

def _get_redis_client(self) -> "aioredis.Redis":
if self.redis_client is None:
raise DBProvaiderError(self.redis_error_msg)
return self.redis_client

async def init_redis(self) -> bool:
if not self.redis_enabled:
raise DBProvaiderError(self.redis_error_msg)
if self.redis_client is None:
# Create new Redis connection pool
if not self.username or not self.redis_password:
self.redis_client = await aioredis.from_url(
url="redis://localhost", encoding="UTF-8", **self.options
)
else:
self.redis_client = await aioredis.from_url(
url=f"redis://{self.username}:{self.redis_password}@localhost:{self.redis_port}/{self.redis_db}", # noqa: E501
encoding="UTF-8",
**self.options,
)
url = f"redis://{self.redis_host}:{self.redis_port}/{self.redis_db}"
self.redis_client = await aioredis.from_url(
url=url,
encoding="UTF-8",
username=self.username,
password=self.redis_password,
**self.options,
)
else:
# Validate that the provided client is an async Redis client
# Validate that the provided client is an async Redis client.
if not isinstance(self.redis_client, aioredis.Redis):
raise DBProvaiderError(
"Provided redis_client must be an async Redis client from redis.asyncio. "
f"Received type: {type(self.redis_client)}. "
"Use: from redis.asyncio import Redis; client = Redis.from_url(...)"
)

temp_counter = await self.redis_client.get("temp_counter")
domain_counter = await self.redis_client.get("domain_counter")
blocked_emails = await self.redis_client.get("email_counter")
redis_client = self._get_redis_client()
temp_counter = await redis_client.get("temp_counter")
domain_counter = await redis_client.get("domain_counter")
blocked_emails = await redis_client.get("email_counter")

if not temp_counter:
await self.redis_client.set("temp_counter", 0)
await redis_client.set("temp_counter", 0)
if not domain_counter:
await self.redis_client.set("domain_counter", 0)
await redis_client.set("domain_counter", 0)
if not blocked_emails:
await self.redis_client.set("email_counter", 0)
await redis_client.set("email_counter", 0)
temp_domains = await self.fetch_temp_email_domains()
check_key = await self.redis_client.hgetall("temp_domains")
check_key = await redis_client.hgetall("temp_domains")
if not check_key:
kwargs = {
domain: await self.redis_client.incr("temp_counter")
domain: await redis_client.incr("temp_counter")
for domain in temp_domains
}
await self.redis_client.hset("temp_domains", mapping=kwargs)
await redis_client.hset("temp_domains", mapping=kwargs)

return True

Expand All @@ -225,56 +226,62 @@ async def fetch_temp_email_domains(self) -> Union[List[str], Any]:
async def blacklist_add_domain(self, domain: str) -> None:
"""Add domain to blacklist"""
if self.redis_enabled:
result = await self.redis_client.hget("blocked_domains", domain)
redis_client = self._get_redis_client()
result = await redis_client.hget("blocked_domains", domain)
if not result:
incr = await self.redis_client.incr("domain_counter")
await self.redis_client.hset("blocked_domains", domain, incr)
incr = await redis_client.incr("domain_counter")
await redis_client.hset("blocked_domains", domain, incr)
else:
self.BLOCKED_DOMAINS.add(domain)

async def blacklist_rm_domain(self, domain: str) -> None:
if self.redis_enabled:
res = await self.redis_client.hdel("blocked_domains", domain)
redis_client = self._get_redis_client()
res = await redis_client.hdel("blocked_domains", domain)
if res:
await self.redis_client.decr("domain_counter")
await redis_client.decr("domain_counter")
else:
self.BLOCKED_DOMAINS.remove(domain)

async def blacklist_add_email(self, email: str) -> None:
"""Add email address to blacklist"""
if self.validate_email(email):
if self.redis_enabled:
blocked_domain = await self.redis_client.hget("blocked_emails", email)
redis_client = self._get_redis_client()
blocked_domain = await redis_client.hget("blocked_emails", email)
if not blocked_domain:
inc = await self.redis_client.incr("email_counter")
await self.redis_client.hset("blocked_emails", email, inc)
inc = await redis_client.incr("email_counter")
await redis_client.hset("blocked_emails", email, inc)
else:
self.BLOCKED_ADDRESSES.add(email)

async def blacklist_rm_email(self, email: str) -> None:
if self.redis_enabled:
res = await self.redis_client.hdel("blocked_emails", email)
redis_client = self._get_redis_client()
res = await redis_client.hdel("blocked_emails", email)
if res:
await self.redis_client.decr("email_counter")
await redis_client.decr("email_counter")
else:
self.BLOCKED_ADDRESSES.remove(email)

async def add_temp_domain(self, domain_lists: List[str]) -> None:
"""Manually add temporary email"""
if self.redis_enabled:
redis_client = self._get_redis_client()
for domain in domain_lists:
temp_email = await self.redis_client.hget("temp_domains", domain)
temp_email = await redis_client.hget("temp_domains", domain)
if not temp_email:
incr = await self.redis_client.incr("temp_counter")
await self.redis_client.hset("temp_domains", domain, incr)
incr = await redis_client.incr("temp_counter")
await redis_client.hset("temp_domains", domain, incr)
else:
self.TEMP_EMAIL_DOMAINS.extend(domain_lists)

async def blacklist_rm_temp(self, domain: str) -> bool:
if self.redis_enabled:
res = await self.redis_client.hdel("temp_domains", domain)
redis_client = self._get_redis_client()
res = await redis_client.hdel("temp_domains", domain)
if res:
await self.redis_client.decr("temp_counter")
await redis_client.decr("temp_counter")
else:
self.TEMP_EMAIL_DOMAINS.remove(domain)
return True
Expand All @@ -285,7 +292,8 @@ async def is_disposable(self, email: str) -> bool:
_, domain = email.split("@")
result = None
if self.redis_enabled:
result = await self.redis_client.hget("temp_domains", domain)
redis_client = self._get_redis_client()
result = await redis_client.hget("temp_domains", domain)
return bool(result)
return domain in self.TEMP_EMAIL_DOMAINS
return False
Expand All @@ -295,7 +303,8 @@ async def is_blocked_domain(self, domain: str) -> bool:
if not self.redis_enabled:
return domain in self.BLOCKED_DOMAINS

blocked_email = await self.redis_client.hget("blocked_domains", domain)
redis_client = self._get_redis_client()
blocked_email = await redis_client.hget("blocked_domains", domain)
return bool(blocked_email)

async def is_blocked_address(self, email: str) -> bool:
Expand All @@ -304,7 +313,8 @@ async def is_blocked_address(self, email: str) -> bool:
if not self.redis_enabled:
return email in self.BLOCKED_ADDRESSES

blocked_domain = await self.redis_client.hget("blocked_emails", email)
redis_client = self._get_redis_client()
blocked_domain = await redis_client.hget("blocked_emails", email)
return bool(blocked_domain)
return False

Expand All @@ -330,31 +340,35 @@ async def check_mx_record(
async def blocked_email_count(self) -> int:
"""count all blocked emails in redis"""
if self.redis_enabled:
result = await self.redis_client.get("email_counter")
redis_client = self._get_redis_client()
result = await redis_client.get("email_counter")
if result is not None:
return result
return len(self.BLOCKED_ADDRESSES)

async def blocked_domain_count(self) -> int:
"""count all blocked domains in redis"""
if self.redis_enabled:
result = await self.redis_client.get("domain_counter")
redis_client = self._get_redis_client()
result = await redis_client.get("domain_counter")
if result is not None:
return result
return len(self.BLOCKED_DOMAINS)

async def temp_email_count(self) -> int:
"""count all temporary emails in redis"""
if self.redis_enabled:
result = await self.redis_client.get("temp_counter")
redis_client = self._get_redis_client()
result = await redis_client.get("temp_counter")
if result is not None:
return result
return len(self.TEMP_EMAIL_DOMAINS)

async def close_connections(self) -> bool:
"""for correctly close connection from redis"""
if self.redis_enabled:
await self.redis_client.close()
redis_client = self._get_redis_client()
await redis_client.aclose() # type: ignore[attr-defined]
return True
raise DBProvaiderError(self.redis_error_msg)

Expand Down
4 changes: 2 additions & 2 deletions fastapi_mail/msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ async def attach_file(self, message: MIMEMultipart, attachment: Any):
part = MIMEBase(
_maintype=file_meta["mime_type"], _subtype=file_meta["mime_subtype"]
)

# If the file-like object has a content-type header,
# use that to determine the MIME type of the attachment
elif hasattr(file, 'headers') and file.headers.get("content-type"):
elif hasattr(file, "headers") and file.headers.get("content-type"):
content_type = file.headers.get("content-type")
if "/" in content_type:
_maintype, _subtype = content_type.split("/", 1)
Expand Down
10 changes: 2 additions & 8 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,7 @@ async def test_attachement_message(mail_config):

assert len(outbox) == 1
assert mail._payload[1].get_content_maintype() == "text"
assert (
mail._payload[1].__dict__.get("_headers")[0][1]
== "text/plain"
)
assert mail._payload[1].__dict__.get("_headers")[0][1] == "text/plain"


@pytest.mark.asyncio
Expand Down Expand Up @@ -376,10 +373,7 @@ async def test_send_msg_with_alternative_body_and_attachements(mail_config):

assert mail._payload[1].get_content_maintype() == "text"

assert (
mail._payload[1].__dict__.get("_headers")[0][1]
== "text/plain"
)
assert mail._payload[1].__dict__.get("_headers")[0][1] == "text/plain"


@pytest.mark.asyncio
Expand Down
67 changes: 67 additions & 0 deletions tests/test_redis_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

from fastapi_mail.email_utils import DefaultChecker


@pytest.mark.asyncio
async def test_redis_checker(redis_checker):
Expand Down Expand Up @@ -31,3 +33,68 @@ async def test_redis_checker(redis_checker):

assert await redis_checker.is_blocked_address(email) is True
assert await redis_checker.check_mx_record(domain) is True


class FakeRedisClient:
async def get(self, key):
return 1

async def hgetall(self, key):
return {"existing": 1}


async def temp_domains():
return ["example.com"]


@pytest.mark.asyncio
async def test_default_redis_connection_uses_localhost_fallback(monkeypatch):
urls = []

async def from_url(url, **kwargs):
urls.append(url)
return FakeRedisClient()

checker = DefaultChecker(db_provider="redis")
monkeypatch.setattr(
"fastapi_mail.email_utils.email_check.aioredis.from_url", from_url
)
monkeypatch.setattr(checker, "fetch_temp_email_domains", temp_domains)

await checker.init_redis()

assert urls == ["redis://localhost:6379/0"]


@pytest.mark.asyncio
async def test_custom_redis_host_overrides_localhost_fallback(monkeypatch):
urls = []

async def from_url(url, **kwargs):
urls.append(url)
return FakeRedisClient()

checker = DefaultChecker(db_provider="redis", redis_host="redis")
monkeypatch.setattr(
"fastapi_mail.email_utils.email_check.aioredis.from_url", from_url
)
monkeypatch.setattr(checker, "fetch_temp_email_domains", temp_domains)

await checker.init_redis()

assert urls == ["redis://redis:6379/0"]


@pytest.mark.asyncio
async def test_existing_redis_client_does_not_create_new_connection(
monkeypatch, redis_checker
):
async def from_url(url, **kwargs):
raise AssertionError("from_url should not be called for an existing client")

monkeypatch.setattr(
"fastapi_mail.email_utils.email_check.aioredis.from_url", from_url
)
monkeypatch.setattr(redis_checker, "fetch_temp_email_domains", temp_domains)

await redis_checker.init_redis()