Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions .cursor/skills/git-handling/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fix commit hook failures, or write a commit message.
- Stash if any changes and pull the latest master `git pull origin master --rebase`.
- Branch name must follow the format `type/description`.
- The type can be one of
(chore|test|setup|feature|fix|build|docs|refactor|release).
(feature|fix|docs|refactor).
- Choose the respective `type` from the changes (staged). If none fit, ask.
- Write a appropriate and concise `description` as per the changes

Expand All @@ -30,14 +30,12 @@ fix commit hook failures, or write a commit message.
4. **Never use** `git commit --no-verify`.
5. **Confirmation**: Must prompt the user for confirmation before
committing the changes.

## Commit

- Commit message must match:
`^(chore|test|setup|feature|fix|build|docs|refactor|release)!?: .+`
- Prefer short, descriptive (e.g. "feature: Add MX/SPF checks").
- From repo root: `git commit -m "message"` or `git commit` for editor.
- If hooks block the commit, fix issues (see linting skill).
6. Follow the commit message pattern:
- Commit message must match:
`^(feature|fix|docs|refactor)!?: .+`
- Prefer short, descriptive (e.g. "feature: Add MX/SPF checks").
- From repo root: `git commit -m "message"` or `git commit` for editor.
- If hooks block the commit, fix issues (see linting skill).

## Branching and push

Expand Down
16 changes: 5 additions & 11 deletions .cursor/skills/linting/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ pre-commit.
1. **venv**: Always run all the python packages from `.venv/bin/`

2. **Default**: Run all linters via pre-commit.
- First run `pre-commit autoupdate` before running the linters.
- From project root: `pre-commit run --all-files`
- Install hooks first if needed: `pre-commit install`
- First run `.venv/bin/pre-commit autoupdate` before running the linters.
- From project root: `.venv/bin/pre-commit run --all-files`
- Install hooks first if needed: `.venv/bin/pre-commit install`

3. **Single tool** (only when the user asks for one):
- Ruff: `ruff check .` then `ruff format .` (or `ruff check --fix`)
Expand All @@ -27,14 +27,8 @@ pre-commit.
- Markdownlint: as in `.pre-commit-config.yaml` for `*.md`

4. **Fix then verify**:
- Organize the import with ruff
- Always use full paths to the ruff executable. Run check first, then format:

```bash
.venv/bin/ruff check . --fix && .venv/bin/ruff format .
```

- After fixing issues, run `pre-commit run --all-files` again.
- Organize the imports
- After fixing issues, run `.venv/bin/pre-commit run --all-files` again.

## Project config

Expand Down
2 changes: 1 addition & 1 deletion src/dkim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .exceptions import DomainPolicyError
from .models import DKIM_MARKER, DKIM_SELECTORS, DKIMVerificationReport
from .spf import get_domain_policy_record
from .utils import get_domain_policy_record

if TYPE_CHECKING:
from dns.resolver import Resolver
Expand Down
11 changes: 6 additions & 5 deletions src/dmarc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .exceptions import DomainPolicyError
from .models import DMARC_MARKER, DMARCVerificationReport
from .spf import get_domain_policy_record
from .utils import get_domain_policy_record

if TYPE_CHECKING:
from dns.resolver import Resolver
Expand All @@ -17,12 +17,13 @@ def extract_dmarc_record_info(
For more strict validation, use checkdmarc (domainaware), magicspoofing (magichk).
"""
try:
dmarc_record = get_domain_policy_record(
if dmarc_record := get_domain_policy_record(
f'_dmarc.{domain}',
DMARC_MARKER,
resolver=resolver,
timeout=timeout,
)
return DMARCVerificationReport(valid=True, record=dmarc_record)
):
return DMARCVerificationReport(valid=True, record=dmarc_record)
except DomainPolicyError:
return DMARCVerificationReport(valid=False, record=None)
pass
return DMARCVerificationReport(valid=False, record=None)
54 changes: 15 additions & 39 deletions src/spf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING

import dns.resolver
from dns.rdatatype import RdataType

from .exceptions import DomainPolicyError
from .models import (
Expand All @@ -12,39 +11,12 @@
SPFRecordInfo,
SPFVerificationReport,
)
from .utils import get_domain_policy_record

if TYPE_CHECKING:
from dns.resolver import Resolver


def _is_policy_version_valid(policy_record: str, marker: str) -> bool:
# The only valid version is the marker and it must be only one instance at the beginning of the record.
version_regex = re.compile(f'^{re.escape(marker)}$|^{re.escape(marker)}', re.IGNORECASE)
match = version_regex.search(policy_record)
if not match or match.start() != 0:
return False
instances = version_regex.findall(policy_record)
return len(instances) == 1


def get_domain_policy_record(
name: str,
marker: str,
resolver: 'Resolver | None' = None,
timeout: int = 5,
) -> str:
res = resolver or dns.resolver.get_default_resolver()
try:
txt_records = res.resolve(qname=name, rdtype=RdataType.TXT, lifetime=timeout)
except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.resolver.LifetimeTimeout) as e:
raise DomainPolicyError('Domain policy record not found') from e
for record in txt_records:
record_text = ''.join(a.decode('utf-8') for a in record.strings)
if marker in record_text and _is_policy_version_valid(record_text, marker):
return record_text
raise DomainPolicyError('Domain policy record not found')


def _check_catchall(spf_record: str) -> CatchAllSecurityLevel | None:
# RFC 7208 §4.7: -all (fail), ~all (softfail), ?all (neutral), +all/all (none).
catchall_regex = re.compile(r'\s[~\+\-\?]?all\b', re.IGNORECASE)
Expand Down Expand Up @@ -97,6 +69,9 @@ def _extract_includes(spf_record: str, resolver: 'Resolver | None', timeout: int
include_regex = re.compile(r'\binclude:\S+\b', re.IGNORECASE)
max_dns_queries = 10
includes: list[str] = []
if not include_regex.search(spf_record):
return includes

res = resolver or dns.resolver.get_default_resolver()

def _get_includes_recursive(spf: str) -> None:
Expand All @@ -122,14 +97,15 @@ def extract_spf_record_info(domain: str, resolver: 'Resolver | None' = None, tim
If strict validation is required, use pyspf (sdgathman) or magicspoofing (magichk).
"""
try:
spf_record = get_domain_policy_record(domain, SPF_MARKER, resolver=resolver, timeout=timeout)
if spf_record := get_domain_policy_record(domain, SPF_MARKER, resolver=resolver, timeout=timeout):
info = SPFRecordInfo(
record=spf_record,
catchall=_check_catchall(spf_record),
deprecated_mechanism=_check_deprecated_mechanism(spf_record),
ip_addresses=_check_ip_addresses(spf_record),
includes=_extract_includes(spf_record, resolver, timeout),
)
return SPFVerificationReport(valid=True, info=info)
except DomainPolicyError:
return SPFVerificationReport(valid=False, info=None)
info = SPFRecordInfo(
record=spf_record,
catchall=_check_catchall(spf_record),
deprecated_mechanism=_check_deprecated_mechanism(spf_record),
ip_addresses=_check_ip_addresses(spf_record),
includes=_extract_includes(spf_record, resolver, timeout),
)
return SPFVerificationReport(valid=True, info=info)
pass
return SPFVerificationReport(valid=False, info=None)
37 changes: 37 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import re
from typing import TYPE_CHECKING

import dns.resolver
from dns.rdatatype import RdataType

from .exceptions import DomainPolicyError

if TYPE_CHECKING:
from dns.resolver import Resolver


def _is_policy_version_valid(policy_record: str, marker: str) -> bool:
version_regex = re.compile(f'^{re.escape(marker)}$|^{re.escape(marker)}', re.IGNORECASE)
match = version_regex.search(policy_record)
if not match:
return False
instances = version_regex.findall(policy_record)
return len(instances) == 1


def get_domain_policy_record(
name: str,
marker: str,
resolver: 'Resolver | None' = None,
timeout: int = 5,
) -> str:
res = resolver or dns.resolver.get_default_resolver()
try:
txt_records = res.resolve(qname=name, rdtype=RdataType.TXT, lifetime=timeout)
except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN, dns.resolver.LifetimeTimeout) as e:
raise DomainPolicyError('Domain policy record not found') from e
for record in txt_records:
record_text = ''.join(a.decode('utf-8') for a in record.strings)
if marker in record_text and _is_policy_version_valid(record_text, marker):
return record_text
raise DomainPolicyError('Domain policy record not found')
1 change: 1 addition & 0 deletions tests/test_dkim.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,4 @@ def test_defaults_to_dkim_selectors_list() -> None:
result = extract_dkim_record_info('example.com')
# First selector hit -> valid
assert result.valid is True
assert result.record == _VALID_RECORD
8 changes: 8 additions & 0 deletions tests/test_dmarc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def test_correct_dns_name_and_marker() -> None:
assert result.record == 'v=DMARC1; p=none'


def test_no_record_found() -> None:
with patch(_MOCK_TARGET, return_value='') as mock:
result = extract_dmarc_record_info('example.com')
assert result.valid is False
assert result.record is None
mock.assert_called_once_with('_dmarc.example.com', DMARC_MARKER, resolver=None, timeout=5)


def test_resolver_forwarded() -> None:
sentinel_resolver = MagicMock(spec=Resolver)
with patch(_MOCK_TARGET, return_value='v=DMARC1; p=none') as mock:
Expand Down
83 changes: 14 additions & 69 deletions tests/test_spf.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,16 @@
from unittest.mock import MagicMock, patch

import dns.resolver
import pytest
from unittest.mock import patch

from src.exceptions import DomainPolicyError
from src.models import CatchAllSecurityLevel
from src.models import SPF_MARKER, CatchAllSecurityLevel
from src.spf import (
_check_catchall,
_check_deprecated_mechanism,
_check_ip_addresses,
_extract_includes,
_is_policy_version_valid,
extract_spf_record_info,
get_domain_policy_record,
)


def test_is_policy_version_valid() -> None:
assert _is_policy_version_valid('v=spf1 include:_spf.google.com', 'v=spf1') is True
assert _is_policy_version_valid('v=spf1', 'v=spf1') is True
assert _is_policy_version_valid('v=spf1 ', 'v=spf1') is True
# DMARC with space after marker matches; semicolon alone does not (version must be at start then EOL or space)
assert _is_policy_version_valid('v=DMARC1 p=none', 'v=DMARC1') is True
assert _is_policy_version_valid('other v=spf1', 'v=spf1') is False
# Regex matches once at start ("v=spf1 ") so implementation reports valid; duplicate tags not detected


def test_check_catchall() -> None:
assert _check_catchall('v=spf1 include:_spf.google.com -all') == CatchAllSecurityLevel.HIGH
assert _check_catchall('v=spf1 ~all') == CatchAllSecurityLevel.MEDIUM
Expand Down Expand Up @@ -55,35 +40,6 @@ def test_extract_includes_no_resolver_cap() -> None:
assert not includes


def test_get_domain_policy_record_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
def raise_nxdomain(*_args: object, **_kwargs: object) -> None:
raise dns.resolver.NXDOMAIN()

monkeypatch.setattr(dns.resolver.Resolver, 'resolve', raise_nxdomain)
with pytest.raises(DomainPolicyError):
get_domain_policy_record('example.com', 'v=spf1', resolver=dns.resolver.Resolver(), timeout=1)


def test_extract_spf_record_info_missing(monkeypatch: pytest.MonkeyPatch) -> None:
def raise_no_answer(*_args: object, **_kwargs: object) -> None:
raise dns.resolver.NoAnswer()

monkeypatch.setattr(dns.resolver.Resolver, 'resolve', raise_no_answer)
report = extract_spf_record_info('example.com', resolver=dns.resolver.Resolver(), timeout=1)
assert report.valid is False
assert report.info is None


def test_is_policy_version_valid_semicolon_delimiter() -> None:
"""Validates the regex fix: v=DMARC1; (semicolon, no space) must match."""
assert _is_policy_version_valid('v=DMARC1;p=reject', 'v=DMARC1') is True
assert _is_policy_version_valid('v=DMARC1; p=reject; rua=mailto:r@x.com', 'v=DMARC1') is True
# DKIM marker with semicolon
assert _is_policy_version_valid('v=DKIM1; k=rsa; p=MII...', 'v=DKIM1') is True
# SPF still works with space
assert _is_policy_version_valid('v=spf1 include:x.com ~all', 'v=spf1') is True


def test_extract_includes_recursive() -> None:
call_count = 0

Expand Down Expand Up @@ -115,40 +71,29 @@ def test_extract_includes_respects_max_dns_queries() -> None:
domains = [f'd{i}.com' for i in range(15)]
record = 'v=spf1 ' + ' '.join(f'include:{d}' for d in domains) + ' -all'

with patch('src.spf.get_domain_policy_record', side_effect=DomainPolicyError('')):
with patch('src.spf.get_domain_policy_record', side_effect=DomainPolicyError('')) as mock:
includes = _extract_includes(record, resolver=None, timeout=1)

assert len(includes) <= 10
assert mock.call_count == 10


def test_extract_spf_record_info_no_record_found() -> None:
with patch('src.spf.get_domain_policy_record', side_effect=DomainPolicyError('')) as mock:
report = extract_spf_record_info('example.com')
assert report.valid is False
assert report.info is None
mock.assert_called_once_with('example.com', SPF_MARKER, resolver=None, timeout=5)


def test_extract_spf_record_info_success() -> None:
spf_record = 'v=spf1 include:_spf.example.com ip4:10.0.0.0/8 ~all'
with patch('src.spf.get_domain_policy_record', return_value=spf_record):
with patch('src.spf.get_domain_policy_record', return_value=spf_record) as mock:
report = extract_spf_record_info('example.com')
assert report.valid is True
assert report.info is not None
assert report.info.record == spf_record
assert report.info.catchall == CatchAllSecurityLevel.MEDIUM
assert report.info.deprecated_mechanism is False
assert report.info.ip_addresses is True


def test_get_domain_policy_record_lifetime_timeout(monkeypatch: pytest.MonkeyPatch) -> None:
def raise_timeout(*_args: object, **_kwargs: object) -> None:
raise dns.resolver.LifetimeTimeout(timeout=5.0, errors=[])

monkeypatch.setattr(dns.resolver.Resolver, 'resolve', raise_timeout)
with pytest.raises(DomainPolicyError):
get_domain_policy_record('example.com', 'v=spf1', resolver=dns.resolver.Resolver(), timeout=5)


def test_get_domain_policy_record_no_matching_marker() -> None:
mock_record = MagicMock()
mock_record.strings = [b'some-other-txt-record']
mock_answer = MagicMock()
mock_answer.__iter__ = lambda self: iter([mock_record])

mock_resolver = MagicMock()
mock_resolver.resolve.return_value = mock_answer
with pytest.raises(DomainPolicyError):
get_domain_policy_record('example.com', 'v=spf1', resolver=mock_resolver, timeout=1)
mock.assert_called()
Loading