diff --git a/.github/README.md b/.github/README.md new file mode 100644 index 00000000..2a048cff --- /dev/null +++ b/.github/README.md @@ -0,0 +1,65 @@ +# GitHub Actions and Workflows + +This directory contains the CI/CD configuration for **network_wrangler**: workflows that run on pull requests and pushes, plus reusable composite actions used by those workflows. + +--- + +## Workflows + +Workflows live in [`.github/workflows/`](workflows/). They run on different triggers and coordinate linting, testing, docs, benchmarks, and releases. + +| Workflow | Trigger | Lint | Format | Tests | Benchmark | Docs | Coverage | +|----------|---------|:----:|:------:|:-----:|:---------:|:----:|:--------:| +| **PR Checks** (`pullrequest.yml`) | `pull_request` (opened, synchronize, reopened) | ✓ | ✓ (fix) | ✓ | ✓ | ✓ | ✓ | +| **CI** (`push.yml`) | `push` to `main` or `develop` | ✓ | ✓ (check) | ✓ | ✓ | ✓ | — | +| **Prepare Release** (`prepare-release.yml`) | Release **created** or manual | — | — | — | — | — | — | +| **Publish Release** (`publish.yml`) | Release **published** or manual | — | — | — | — | ✓ | — | +| **Clean Documentation** (`clean-docs.yml`) | Branch/tag **deleted** or PR **closed** | — | — | — | — | ✓ (delete) | — | + +- **Lint**: `ruff check` (PR: auto-fix and commit; Push: check only). +- **Format**: `ruff format` (PR: apply fixes; Push: check only). +- **Tests**: pytest on Python 3.10–3.13 +- **Docs**: build and deploy to GitHub Pages (PR/Push/Release); **Clean** removes a version when a branch is deleted or PR closed. +- **Benchmark**: run and compare benchmarks. +- **Coverage**: post coverage comment on PR (when base is `main`/`develop`). + +### Workflow details + +- **PR Checks** + - Lint job can auto-commit formatting/lint fixes to the PR branch (with `[skip ci]`). + - Tests run in a matrix (3.10–3.13); only 3.13 produces coverage and benchmark artifacts. + - Benchmark and coverage jobs run on Python 3.13 and only when the PR base is `main` or `develop`. 
+ - Docs are built per-PR branch; a comment with the docs URL is posted when the PR is opened. + +- **Push (main/develop)** + - Same test matrix and artifact strategy. + - Benchmark comparison is only on Python 3.13 and against the previous commit on the branch. + - Docs are deployed for the pushed branch name. + +- **Releases** + - **Prepare**: runs on release *created* (or manual); ensures version matches tag, publishes to TestPyPI, and verifies install. + - **Publish**: runs on release *published* as a pre-release or release (or manual); publishes to PyPI and then deploys release docs. + +- **Clean docs** + - Uses `get-branch-name` to resolve the branch/tag from the event, then deletes that version from the docs site (skips `main` and `develop`). + +--- + +## Reusable Actions + +Reusable actions live in [`.github/actions/`](actions/). Workflows call them with `uses: ./.github/actions/`. + +| Action | Purpose | +|--------|---------| +| **setup-python-uv** | Sets up the requested Python version, installs [uv](https://github.com/astral-sh/uv), and caches UV packages (keyed by `pyproject.toml`). Used by lint, test, docs, and benchmark jobs. | +| **get-branch-name** | Outputs a normalized branch (or tag) name from the GitHub event (`push`, `pull_request`, `delete`, etc.). Used by docs and clean-docs workflows. | +| **build-docs** | Installs deps with `.[docs]`, runs `mike deploy` for the given branch name, and updates the `latest` alias when the branch is `main`. | +| **compare-benchmarks** | Compares `benchmark.json` either to the previous commit (`push`) or to the base branch (`pr`). Commits `benchmark.json` to the branch and, for PRs, posts a comment with the comparison (and regression warning if applicable). | +| **post-coverage** | Downloads the `coverage-py3.13` artifact, normalizes paths into a `coverage/` directory, and uses `MishaKav/pytest-coverage-comment` to post a coverage comment on the PR. 
| + +--- + +## Other contents + +- **Issue templates** ([`ISSUE_TEMPLATE/`](ISSUE_TEMPLATE/)) – Templates for bugs, features, docs, performance, and chores. +- **Pull request template** ([`pull_request_template.md`](pull_request_template.md)) – Default body for new pull requests. diff --git a/.github/actions/build-docs/action.yml b/.github/actions/build-docs/action.yml new file mode 100644 index 00000000..9e458472 --- /dev/null +++ b/.github/actions/build-docs/action.yml @@ -0,0 +1,47 @@ +name: 'Build and Deploy Docs' +description: 'Build and deploy documentation using mike' +inputs: + python-version: + description: 'Python version to use' + required: true + default: '3.13' + branch-name: + description: 'Branch name for docs version' + required: true + +runs: + using: 'composite' + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Python with UV + uses: ./.github/actions/setup-python-uv + with: + python-version: ${{ inputs.python-version }} + + - name: Configure Git user + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + shell: bash + + - name: Install dependencies + run: | + uv pip install --system -e .[docs] + shell: bash + + - name: Build and deploy docs + continue-on-error: true + run: | + mike deploy --push ${{ inputs.branch-name }} + shell: bash + + - name: Update latest docs alias + if: inputs.branch-name == 'main' + run: | + mike alias ${{ inputs.branch-name }} latest --update-aliases --push + shell: bash + diff --git a/.github/actions/compare-benchmarks/action.yml b/.github/actions/compare-benchmarks/action.yml new file mode 100644 index 00000000..ab3ea2ea --- /dev/null +++ b/.github/actions/compare-benchmarks/action.yml @@ -0,0 +1,203 @@ +name: 'Compare Benchmarks' +description: 'Compare benchmark results between branches or commits and commit benchmark.json to the branch' +inputs: + benchmark-json-path: + 
description: 'Path to benchmark.json file' + required: false + default: 'benchmark.json' + comparison-type: + description: 'Type of comparison: push (compare to previous commit) or pr (compare to base branch)' + required: true + base-branch: + description: 'Base branch name for PR comparisons' + required: false + github-token: + description: 'GitHub token for posting comments and committing' + required: true + pr-number: + description: 'Pull request number for posting comments' + required: false + alert-threshold: + description: 'Alert threshold percentage (e.g., 125 means 25% slower triggers alert)' + required: false + default: '125' + python-version: + description: 'Python version for installing pytest-benchmark' + required: false + default: '3.13' + +runs: + using: 'composite' + steps: + - name: Setup Python with UV + uses: ./.github/actions/setup-python-uv + with: + python-version: ${{ inputs.python-version }} + + - name: Install pytest-benchmark + shell: bash + run: | + uv pip install --system pytest-benchmark + + - name: Compare benchmarks (Push - previous commit) + id: compare-push + if: inputs.comparison-type == 'push' + shell: bash + run: | + set -e + + CURRENT_COMMIT=$(git rev-parse HEAD) + PREVIOUS_COMMIT=$(git rev-parse HEAD~1 2>/dev/null || echo "") + + if [ -z "$PREVIOUS_COMMIT" ]; then + echo "No previous commit found. This appears to be the first commit with benchmarks." + echo "comparison_result=No previous commit to compare against" >> $GITHUB_OUTPUT + exit 0 + fi + + # Check if current benchmark.json exists + if [ ! -f "${{ inputs.benchmark-json-path }}" ]; then + echo "ERROR: Current benchmark.json not found at ${{ inputs.benchmark-json-path }}" + exit 1 + fi + + # Try to get previous commit's benchmark.json + PREV_BENCHMARK=$(git show ${PREVIOUS_COMMIT}:${{ inputs.benchmark-json-path }} 2>/dev/null || echo "") + + if [ -z "$PREV_BENCHMARK" ]; then + echo "No previous benchmark.json found in commit ${PREVIOUS_COMMIT}. 
This is the first benchmark run." + echo "comparison_result=No previous benchmark to compare against" >> $GITHUB_OUTPUT + exit 0 + fi + + # Save previous benchmark to temp file + echo "$PREV_BENCHMARK" > /tmp/previous_benchmark.json + + # Compare benchmarks + echo "Comparing current commit (${CURRENT_COMMIT:0:7}) to previous commit (${PREVIOUS_COMMIT:0:7})" + COMPARE_OUTPUT=$(pytest-benchmark compare /tmp/previous_benchmark.json ${{ inputs.benchmark-json-path }} 2>&1 || echo "Comparison completed") + echo "$COMPARE_OUTPUT" + + echo "comparison_result<<EOF" >> $GITHUB_OUTPUT + echo "$COMPARE_OUTPUT" >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT + + - name: Compare benchmarks (PR - base branch) + id: compare-pr + if: inputs.comparison-type == 'pr' + shell: bash + run: | + set -e + + BASE_BRANCH="${{ inputs.base-branch }}" + CURRENT_COMMIT=$(git rev-parse HEAD) + + if [ -z "$BASE_BRANCH" ]; then + echo "ERROR: Base branch not specified for PR comparison" + exit 1 + fi + + # Check if current benchmark.json exists + if [ ! -f "${{ inputs.benchmark-json-path }}" ]; then + echo "ERROR: Current benchmark.json not found at ${{ inputs.benchmark-json-path }}" + exit 1 + fi + + # Fetch base branch with full history + git fetch origin ${BASE_BRANCH} --depth=100 || git fetch origin ${BASE_BRANCH} + + # Get base branch's latest commit + BASE_COMMIT=$(git rev-parse origin/${BASE_BRANCH} 2>/dev/null || echo "") + + if [ -z "$BASE_COMMIT" ]; then + echo "ERROR: Could not find base branch ${BASE_BRANCH}" + exit 1 + fi + + # Try to get base branch's benchmark.json from the latest commit + BASE_BENCHMARK=$(git show origin/${BASE_BRANCH}:${{ inputs.benchmark-json-path }} 2>/dev/null || echo "") + + if [ -z "$BASE_BENCHMARK" ]; then + echo "WARNING: No benchmark.json found in base branch ${BASE_BRANCH}. Cannot compare." + echo "This PR will be the baseline for future comparisons." 
+ echo "comparison_result=No benchmark.json found in base branch" >> $GITHUB_OUTPUT + echo "has_regression=false" >> $GITHUB_OUTPUT + exit 0 + fi + + # Save base benchmark to temp file + echo "$BASE_BENCHMARK" > /tmp/base_benchmark.json + + # Compare benchmarks + echo "Comparing PR branch (${CURRENT_COMMIT:0:7}) to base branch ${BASE_BRANCH} (${BASE_COMMIT:0:7})" + COMPARE_OUTPUT=$(pytest-benchmark compare /tmp/base_benchmark.json ${{ inputs.benchmark-json-path }} 2>&1 || echo "Comparison completed") + echo "$COMPARE_OUTPUT" + + echo "comparison_result<<EOF" >> $GITHUB_OUTPUT + echo "$COMPARE_OUTPUT" >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT + + # Check for regressions (look for significant slowdown indicators) + if echo "$COMPARE_OUTPUT" | grep -qiE "slower|regression|worse"; then + echo "has_regression=true" >> $GITHUB_OUTPUT + else + echo "has_regression=false" >> $GITHUB_OUTPUT + fi + + - name: Commit benchmark.json to branch + shell: bash + run: | + set -e + + # Configure git + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + + # Check if benchmark.json exists and has changes + if [ ! 
-f "${{ inputs.benchmark-json-path }}" ]; then + echo "No benchmark.json to commit" + exit 0 + fi + + # Check if file is already committed and unchanged + if git diff --quiet HEAD -- "${{ inputs.benchmark-json-path }}" 2>/dev/null; then + echo "benchmark.json is already committed and unchanged" + exit 0 + fi + + # Add and commit benchmark.json + git add "${{ inputs.benchmark-json-path }}" + git commit -m "ci: Update benchmark results [skip ci]" || echo "No changes to commit" + + # Push to current branch + git push || echo "Push failed (may not have permissions or branch may be protected)" + + - name: Post PR comment with comparison + id: post-comment + if: inputs.comparison-type == 'pr' && inputs.pr-number != '' + uses: actions/github-script@v7 + with: + github-token: ${{ inputs.github-token }} + script: | + const comparison = `${{ steps.compare-pr.outputs.comparison_result }}`; + const hasRegression = '${{ steps.compare-pr.outputs.has_regression }}' === 'true'; + + let comment = '## 📊 Benchmark Comparison\n\n'; + comment += `Comparing PR branch to base branch \`${{ inputs.base-branch }}\`\n\n`; + + if (comparison && comparison.trim() && !comparison.includes('No benchmark')) { + comment += '```\n' + comparison + '\n```\n'; + } else { + comment += 'No previous benchmark found in base branch. 
This will serve as the baseline.\n'; + } + + if (hasRegression) { + comment += '\n⚠️ **Performance regression detected!**'; + } + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: ${{ inputs.pr-number }}, + body: comment, + }); diff --git a/.github/actions/get-branch-name/action.yml b/.github/actions/get-branch-name/action.yml new file mode 100644 index 00000000..e593161b --- /dev/null +++ b/.github/actions/get-branch-name/action.yml @@ -0,0 +1,48 @@ +name: 'Get Branch Name' +description: 'Get normalized branch name from GitHub event context' +outputs: + branch-name: + description: 'Normalized branch name (without refs/heads/ prefix)' + value: ${{ steps.branch-name.outputs.value }} + +runs: + using: 'composite' + steps: + - name: Determine branch name + id: branch-name + run: | + if [ "${{ github.event_name }}" = "delete" ]; then + # For delete events, event.ref contains the full ref path + BRANCH_NAME="${{ github.event.ref }}" + # Remove refs/heads/ or refs/tags/ prefix if present + BRANCH_NAME="${BRANCH_NAME#refs/heads/}" + BRANCH_NAME="${BRANCH_NAME#refs/tags/}" + elif [ "${{ github.event_name }}" = "pull_request" ]; then + # For PR events, use the head branch name (source branch of the PR) + BRANCH_NAME="${{ github.event.pull_request.head.ref }}" + elif [ "${{ github.event_name }}" = "push" ]; then + # For push events, ref_name is the branch name without prefix + BRANCH_NAME="${{ github.ref_name }}" + else + # Fallback: try ref_name which works for most event types + BRANCH_NAME="${{ github.ref_name }}" + fi + + # Validate that we got a branch name + if [ -z "$BRANCH_NAME" ] || [ "$BRANCH_NAME" = "" ]; then + echo "ERROR: Could not determine branch name from event" >&2 + echo "Event: ${{ github.event_name }}" >&2 + if [ "${{ github.event_name }}" = "delete" ]; then + echo "Ref: ${{ github.event.ref }}" >&2 + elif [ "${{ github.event_name }}" = "pull_request" ]; then + echo "PR head ref: 
${{ github.event.pull_request.head.ref }}" >&2 + else + echo "Ref name: ${{ github.ref_name }}" >&2 + fi + exit 1 + fi + + echo "Branch name determined: $BRANCH_NAME" + echo "value=$BRANCH_NAME" >> $GITHUB_OUTPUT + shell: bash + diff --git a/.github/actions/post-coverage/action.yml b/.github/actions/post-coverage/action.yml new file mode 100644 index 00000000..4921f2ba --- /dev/null +++ b/.github/actions/post-coverage/action.yml @@ -0,0 +1,77 @@ +name: 'Post Coverage Comment' +description: 'Download coverage artifacts and post coverage comment on PR' +inputs: + pytest-xml-coverage-path: + description: 'Path to pytest XML coverage file' + required: false + default: 'coverage/coverage.xml' + junitxml-path: + description: 'Path to pytest JUnit XML file' + required: false + default: 'coverage/pytest.xml' + +runs: + using: 'composite' + steps: + - name: Download coverage artifacts + uses: actions/download-artifact@v4 + with: + name: coverage-py3.13 + path: ./ + + - name: Verify and prepare coverage files + shell: bash + run: | + echo "=== Files in coverage directory ===" + ls -lah coverage/ || echo "coverage/ directory doesn't exist" + echo "" + echo "=== Searching for coverage files ===" + find . -name "coverage.xml" -o -name "pytest.xml" | head -10 + echo "" + + # Create coverage directory if it doesn't exist + mkdir -p coverage + + # Find and move coverage.xml + COV_FILE=$(find . -name "coverage.xml" -type f | head -1) + if [ -n "$COV_FILE" ] && [ -f "$COV_FILE" ]; then + if [ "$COV_FILE" != "coverage/coverage.xml" ]; then + echo "Found coverage.xml at $COV_FILE, copying to coverage/coverage.xml" + cp "$COV_FILE" coverage/coverage.xml + fi + echo "✓ coverage.xml found and prepared" + else + echo "✗ ERROR: coverage.xml not found anywhere in the artifact" + echo "Artifact structure:" + find . -type f | head -20 + exit 1 + fi + + # Find and move pytest.xml + PYTEST_FILE=$(find . 
-name "pytest.xml" -type f | head -1) + if [ -n "$PYTEST_FILE" ] && [ -f "$PYTEST_FILE" ]; then + if [ "$PYTEST_FILE" != "coverage/pytest.xml" ]; then + echo "Found pytest.xml at $PYTEST_FILE, copying to coverage/pytest.xml" + cp "$PYTEST_FILE" coverage/pytest.xml + fi + echo "✓ pytest.xml found and prepared" + else + echo "⚠ WARNING: pytest.xml not found (may be optional)" + fi + + echo "" + echo "=== Final coverage directory contents ===" + ls -lah coverage/ + + # Verify coverage.xml exists before proceeding + if [ ! -f "coverage/coverage.xml" ]; then + echo "ERROR: coverage.xml is required but not found after preparation" + exit 1 + fi + + - name: Pytest coverage comment + uses: MishaKav/pytest-coverage-comment@main + with: + pytest-xml-coverage-path: ${{ inputs.pytest-xml-coverage-path }} + junitxml-path: ${{ inputs.junitxml-path }} + diff --git a/.github/actions/setup-python-uv/action.yml b/.github/actions/setup-python-uv/action.yml new file mode 100644 index 00000000..f5b8f45c --- /dev/null +++ b/.github/actions/setup-python-uv/action.yml @@ -0,0 +1,28 @@ +name: 'Setup Python with UV' +description: 'Common setup for Python, UV, and package caching' +inputs: + python-version: + description: 'Python version to use' + required: true + +runs: + using: 'composite' + steps: + - name: Set up Python ${{ inputs.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Cache UV packages + uses: actions/cache@v4 + with: + path: ~/.cache/uv + key: uv-packages-${{ runner.os }}-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + uv-packages-${{ runner.os }}- + diff --git a/.github/workflows/clean-docs.yml b/.github/workflows/clean-docs.yml index 619a324c..7d02ebd7 100644 --- a/.github/workflows/clean-docs.yml +++ b/.github/workflows/clean-docs.yml @@ -1,33 +1,70 @@ -name: Clean Docs for Deleted References +name: Clean 
Documentation Versions + on: delete: + # Triggers when a branch or tag is deleted + pull_request: + types: [closed] + # Triggers when a PR is closed (merged or unmerged) + +concurrency: + # Group by workflow and event to avoid duplicate runs for the same event + group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event.ref || github.event.pull_request.number || github.sha }} + cancel-in-progress: false jobs: - build: + delete-docs: + name: Delete documentation version runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.8] + # Delete docs for any closed PR (merged or not) or deleted branches + # Closed PRs don't need docs since the branch is likely being deleted steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - name: Document branch deleting - run: echo ${{ github.ref_name }} - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install mike - run: | - python -m pip install --upgrade pip - pip install mike - - name: Configure Git user - run: | - git config --local user.email "github-actions[bot]@users.noreply.github.com" - git config --local user.name "github-actions[bot]" - - name: Delete defunct docs versions - run: | - echo "Deleting ${{ github.event.ref_name }} version from docs" - mike delete --rebase --push ${{ github.event.ref_name }} + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Python with UV + uses: ./.github/actions/setup-python-uv + with: + python-version: "3.13" + + - name: Configure Git user + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + shell: bash + + - name: Install dependencies + run: | + uv pip install --system -e .[docs] + shell: bash + + - name: Get branch name + id: branch + uses: ./.github/actions/get-branch-name + + - name: Delete documentation version + 
continue-on-error: true + run: | + BRANCH_NAME="${{ steps.branch.outputs.branch-name }}" + echo "Event type: ${{ github.event_name }}" + echo "Deleting documentation version for branch: $BRANCH_NAME" + + # Skip deletion for main and develop branches + if [ "$BRANCH_NAME" = "main" ] || [ "$BRANCH_NAME" = "develop" ]; then + echo "Skipping deletion for protected branch: $BRANCH_NAME" + exit 0 + fi + + # Check if version exists before trying to delete + echo "Checking if documentation version exists..." + if mike list | grep -q "^$BRANCH_NAME$"; then + echo "Version $BRANCH_NAME exists, attempting to delete..." + mike delete --rebase --push "$BRANCH_NAME" + echo "Successfully deleted documentation version: $BRANCH_NAME" + else + echo "Version $BRANCH_NAME does not exist in documentation, nothing to delete" + fi + shell: bash diff --git a/.github/workflows/prepare-release.yml b/.github/workflows/prepare-release.yml new file mode 100644 index 00000000..8fc72605 --- /dev/null +++ b/.github/workflows/prepare-release.yml @@ -0,0 +1,88 @@ +name: Prepare Release + +on: + # Run on release creation to validate and test + release: + types: [created] + # Manual trigger for testing + workflow_dispatch: + +jobs: + prepare-release: + name: Prepare and Validate Release + runs-on: ubuntu-latest + environment: + name: testpypi + url: https://test.pypi.org/p/network-wrangler + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Validate version matches tag + id: version-check + run: | + TAG="${GITHUB_REF_NAME}" # e.g. v1.2.3 + if [[ ! "$TAG" =~ ^v[0-9]+\.[0-9]+(\.[0-9]+)?([a-zA-Z0-9\.\-]+)?$ ]]; then + echo "Release tag must look like v1.2.3 or v1.0-beta.5 (optionally with pre-release suffix). 
Got: $TAG" + exit 1 + fi + VERSION="${TAG#v}" + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + PKG_VERSION=$(grep -m1 '__version__' network_wrangler/__init__.py | sed 's/.*"\(.*\)".*/\1/') + if [ "$PKG_VERSION" != "$VERSION" ]; then + echo "ERROR: tag version $VERSION != network_wrangler.__version__ $PKG_VERSION" + exit 1 + fi + echo "OK: network_wrangler.__version__ matches tag ($VERSION)" + + - name: Check if version already exists on TestPyPI + run: | + VERSION="${{ steps.version-check.outputs.version }}" + RESPONSE=$(curl -s "https://test.pypi.org/pypi/network-wrangler/$VERSION/json" || echo "not found") + if echo "$RESPONSE" | grep -q '"info"'; then + echo "WARNING: Version $VERSION already exists on TestPyPI" + echo "This is OK for prepare-release, but will be skipped during publish" + else + echo "Version $VERSION does not exist on TestPyPI" + fi + + - name: Check if version already exists on PyPI + run: | + VERSION="${{ steps.version-check.outputs.version }}" + RESPONSE=$(curl -s "https://pypi.org/pypi/network-wrangler/$VERSION/json" || echo "not found") + if echo "$RESPONSE" | grep -q '"info"'; then + echo "ERROR: Version $VERSION already exists on PyPI" + echo "Cannot prepare release - version already published" + exit 1 + else + echo "Version $VERSION does not exist on PyPI" + fi + + - name: Install build tools + run: | + python3 -m pip install --upgrade pip + pip install build --user + + - name: Build distribution + run: python3 -m build + + - name: Publish to TestPyPI for validation + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ + skip-existing: true + verbose: true + + - name: Test installation from TestPyPI + run: | + python3 -m pip install --upgrade pip + pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ network-wrangler==${{ steps.version-check.outputs.version }} + python -c "import network_wrangler; print(f'Successfully installed network-wrangler 
{network_wrangler.__version__}')" + diff --git a/.github/workflows/prerelease.yml b/.github/workflows/prerelease.yml deleted file mode 100644 index d3931ddb..00000000 --- a/.github/workflows/prerelease.yml +++ /dev/null @@ -1,53 +0,0 @@ -# This workflow will build and upload a Python Package to TestPyPI -# https://github.com/marketplace/actions/pypi-publish -name: Test Building + Publishing to TestPyPI -on: - push: - branches: [main, develop] - workflow_dispatch: # Manual trigger -jobs: - build: - name: Build distribution 📦 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.x" - - name: Install pypa/build - run: >- - python3 -m - pip install - build - --user - - name: Build a binary wheel and a source tarball - run: python3 -m build - - name: Store the distribution packages - uses: actions/upload-artifact@v4 - with: - name: python-package-distributions - path: dist/ - publish-to-testpypi: - name: Publish Python 🐍 distribution 📦 to TestPyPI - needs: - - build - runs-on: ubuntu-latest - environment: - name: testpypi - url: https://test.pypi.org/p/network-wrangler - permissions: - id-token: write # IMPORTANT: mandatory for trusted publishing - steps: - - name: Download all the dists - uses: actions/download-artifact@v4 - with: - name: python-package-distributions - path: dist/ - - name: Publish distribution 📦 to TestPyPI - if: steps.compare-versions.outputs.publish == 'true' - uses: pypa/gh-action-pypi-publish@release/v1 - with: - repository-url: https://test.pypi.org/legacy/ - skip-existing: true - verbose: true diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..b52368d7 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,98 @@ +name: Publish Release + +on: + # Publish to PyPI on release publish + release: + types: [published] + # Manual trigger + workflow_dispatch: + inputs: + test: + description: 
'Publish to TestPyPI (true) or PyPI (false)' + required: true + type: boolean + default: false + +jobs: + publish-to-pypi: + name: Publish to PyPI + if: | + github.event_name == 'release' || + (github.event_name == 'workflow_dispatch' && github.event.inputs.test == 'false') + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/network-wrangler + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.x" + + - name: Validate version matches tag + id: version-check + if: github.event_name == 'release' + run: | + TAG="${GITHUB_REF_NAME}" # e.g. v1.2.3 + if [[ ! "$TAG" =~ ^v[0-9]+\.[0-9]+(\.[0-9]+)?([a-zA-Z0-9\.\-]+)?$ ]]; then + echo "Release tag must look like v1.2.3 or v1.0-beta.5 (optionally with pre-release suffix). Got: $TAG" + exit 1 + fi + VERSION="${TAG#v}" + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + PKG_VERSION=$(grep -m1 '__version__' network_wrangler/__init__.py | sed 's/.*"\(.*\)".*/\1/') + if [ "$PKG_VERSION" != "$VERSION" ]; then + echo "ERROR: tag version $VERSION != network_wrangler.__version__ $PKG_VERSION" + exit 1 + fi + echo "OK: network_wrangler.__version__ matches tag ($VERSION)" + + - name: Check if version already exists on PyPI + id: check-version + run: | + VERSION="${{ steps.version-check.outputs.version }}" + if [ -z "$VERSION" ]; then + # For manual dispatch, skip version check + echo "exists=false" >> "$GITHUB_OUTPUT" + exit 0 + fi + + # Check if version exists on PyPI + RESPONSE=$(curl -s "https://pypi.org/pypi/network-wrangler/$VERSION/json" || echo "not found") + if echo "$RESPONSE" | grep -q '"info"'; then + echo "ERROR: Version $VERSION already exists on PyPI" + echo "exists=true" >> "$GITHUB_OUTPUT" + exit 1 + else + echo "Version $VERSION does not exist on PyPI" + echo "exists=false" >> "$GITHUB_OUTPUT" + fi + + - name: Install build 
tools + run: | + python3 -m pip install --upgrade pip + pip install build --user + + - name: Build distribution + run: python3 -m build + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + publish-docs: + name: Publish documentation to GitHub Pages + if: github.event_name == 'release' + needs: + - publish-to-pypi + runs-on: ubuntu-latest + steps: + - name: Build and deploy docs + uses: ./.github/actions/build-docs + with: + python-version: "3.13" + branch-name: ${{ github.ref_name }} diff --git a/.github/workflows/pullrequest.yml b/.github/workflows/pullrequest.yml new file mode 100644 index 00000000..e8110edb --- /dev/null +++ b/.github/workflows/pullrequest.yml @@ -0,0 +1,256 @@ +name: PR Checks + +on: + pull_request: + types: [opened, synchronize, reopened] + # Only run when code changes: + # - opened: when PR is first created + # - synchronize: when new commits are pushed to PR branch + # - reopened: when a closed PR is reopened (may have new code) + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + name: Lint code + runs-on: ubuntu-latest + permissions: + contents: write # Required to commit fixes back to PR + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ref: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + + - name: Setup Python with UV + uses: ./.github/actions/setup-python-uv + with: + python-version: "3.13" + + - name: Install ruff + run: | + uv pip install --system ruff + shell: bash + + - name: Format code with ruff + run: ruff format network_wrangler + continue-on-error: true + + - name: Fix linting issues + run: ruff check --fix --output-format=github network_wrangler + continue-on-error: false + + - name: Check for changes + id: check-changes + run: | + if [ -n "$(git status --porcelain)" ]; then + echo "has_changes=true" >> $GITHUB_OUTPUT + else 
+ echo "has_changes=false" >> $GITHUB_OUTPUT + fi + + - name: Commit fixes + if: steps.check-changes.outputs.has_changes == 'true' + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + git add -A + git commit -m "style: Auto-fix code formatting and linting issues" + git push origin HEAD:${{ github.event.pull_request.head.ref }} + shell: bash + + - name: Check for remaining issues + run: ruff check --output-format=github network_wrangler + + test: + name: Test (Python ${{ matrix.python-version }}) + needs: lint + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13"] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Python with UV + uses: ./.github/actions/setup-python-uv + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + uv pip install --system -e .[tests] + shell: bash + + - name: Run tests + run: | + # Only generate coverage for Python 3.13 to save time and simplify artifacts + if [ "${{ matrix.python-version }}" = "3.13" ]; then + pytest --junitxml=pytest.xml --cov=network_wrangler --cov-report "xml:coverage.xml" --benchmark-save=benchmark --benchmark-json=benchmark.json + else + pytest --benchmark-save=benchmark --benchmark-json=benchmark.json + fi + + - name: Verify coverage files exist + if: matrix.python-version == '3.13' + shell: bash + run: | + echo "=== Checking for coverage files ===" + if [ -f "coverage.xml" ]; then + echo "✓ coverage.xml exists ($(wc -c < coverage.xml) bytes)" + head -20 coverage.xml + else + echo "✗ coverage.xml NOT FOUND" + echo "Current directory: $(pwd)" + echo "Files in current directory:" + ls -lah | head -20 + exit 1 + fi + + if [ -f "pytest.xml" ]; then + echo "✓ pytest.xml exists ($(wc -c < pytest.xml) bytes)" + else + echo "⚠ pytest.xml NOT FOUND" + fi + + - name: Upload coverage artifacts + if: 
matrix.python-version == '3.13' + uses: actions/upload-artifact@v4 + with: + name: coverage-py3.13 + path: | + coverage.xml + pytest.xml + if-no-files-found: error + + - name: Check benchmark files exist + if: | + (github.event.pull_request.base.ref == 'main' || github.event.pull_request.base.ref == 'develop') && + matrix.python-version == '3.13' + shell: bash + run: | + echo "Checking for benchmark files..." + ls -lah benchmark.json 2>/dev/null || echo "benchmark.json not found" + ls -lah .benchmarks/ 2>/dev/null || echo ".benchmarks/ directory not found" + echo "PR base ref: ${{ github.event.pull_request.base.ref }}" + echo "Python version: ${{ matrix.python-version }}" + + - name: Upload benchmark artifacts + if: | + (github.event.pull_request.base.ref == 'main' || github.event.pull_request.base.ref == 'develop') && + matrix.python-version == '3.13' + uses: actions/upload-artifact@v4 + with: + name: benchmark-py3.13 + path: | + benchmark.json + .benchmarks/ + if-no-files-found: warn + + docs: + name: Build and deploy docs + needs: lint + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Get branch name + id: branch + uses: ./.github/actions/get-branch-name + + - name: Build and deploy docs + uses: ./.github/actions/build-docs + with: + python-version: "3.13" + branch-name: ${{ steps.branch.outputs.branch-name }} + + - name: Post docs comment + if: github.event.action == 'opened' + uses: actions/github-script@v7 + continue-on-error: true + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const branchName = '${{ steps.branch.outputs.branch-name }}'; + const docsUrl = `https://network-wrangler.github.io/network_wrangler/${branchName}/`; + const comment = `📚 Documentation has been built and deployed! 
+ + View the docs for this PR: [${docsUrl}](${docsUrl})`; + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: ${{ github.event.pull_request.number }}, + body: comment, + }); + + benchmark: + name: Compare and store benchmarks + needs: test + if: github.event.pull_request.base.ref == 'main' || github.event.pull_request.base.ref == 'develop' + runs-on: ubuntu-latest + permissions: + contents: write # Required to commit benchmark.json to PR branch + pull-requests: write # Required for posting PR comments (PRs are issues in GitHub API) + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Need full history to fetch base branch + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Download benchmark artifacts + uses: actions/download-artifact@v4 + with: + name: benchmark-py3.13 + path: ./ + + - name: Check if benchmark files exist + shell: bash + run: | + if [ ! -f "benchmark.json" ]; then + echo "WARNING: benchmark.json not found. Artifact may not have been uploaded." + echo "This can happen if:" + echo " 1. Tests didn't run benchmarks" + echo " 2. Benchmark files weren't created" + echo " 3. 
Upload step was skipped" + exit 1 + fi + echo "✓ benchmark.json found" + + - name: Compare and commit benchmarks + uses: ./.github/actions/compare-benchmarks + with: + comparison-type: pr + benchmark-json-path: benchmark.json + base-branch: ${{ github.event.pull_request.base.ref }} + github-token: ${{ secrets.GITHUB_TOKEN }} + pr-number: ${{ github.event.pull_request.number }} + python-version: "3.13" + + coverage: + name: Post coverage comment + needs: test + if: github.event.pull_request.base.ref == 'main' || github.event.pull_request.base.ref == 'develop' + runs-on: ubuntu-latest + permissions: + pull-requests: write # Required for posting PR comments (PRs are issues in GitHub API) + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Post coverage comment + uses: ./.github/actions/post-coverage diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 1eb88c3a..69267001 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -1,57 +1,169 @@ -name: Lint, test + build docs +name: CI (Lint, Test, Docs, Benchmark) -on: [push] +on: + push: + branches: [main, develop] jobs: - tests: + lint: + name: Lint code + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Python with UV + uses: ./.github/actions/setup-python-uv + with: + python-version: "3.13" + + - name: Install ruff + run: | + uv pip install --system ruff + shell: bash + - name: Check formatting + run: ruff format --check network_wrangler + continue-on-error: true + + - name: Check linting + run: ruff check --output-format=github network_wrangler + continue-on-error: true + + test: + name: Test (Python ${{ matrix.python-version }}) + needs: lint runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Document branch - run: echo ${{ github.ref_name }} - 
- name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Configure Git user - run: | - git config --local user.email "github-actions[bot]@users.noreply.github.com" - git config --local user.name "github-actions[bot]" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -e .[tests] - - name: Lint - run: ruff check --output-format=github network_wrangler - - name: Run tests - run: | - pytest --junitxml=pytest.xml --cov-report "xml:coverage.xml" --benchmark-save=benchmark --benchmark-json=benchmark.json - - name: Build docs - run: | - mike deploy --push ${{ github.ref_name }} - - name: Update latest docs - if: github.ref == 'refs/heads/main' - run: | - mike alias ${{ github.ref_name }} latest --update-aliases --push - - name: Store benchmark result - uses: benchmark-action/github-action-benchmark@v1 - with: - tool: 'pytest' - output-file-path: benchmark.json - alert-threshold: '125%' - github-token: ${{ secrets.GITHUB_TOKEN }} - comment-on-alert: true - summary-always: true - - name: Pytest coverage comment - if: github.event_name == 'pull_request' - uses: MishaKav/pytest-coverage-comment@main - with: - pytest-xml-coverage-path: coverage.xml - junitxml-path: pytest.xml + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Python with UV + uses: ./.github/actions/setup-python-uv + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + uv pip install --system -e .[tests] + shell: bash + + - name: Run tests + run: | + # Only generate coverage for Python 3.13 to save time and simplify artifacts + if [ "${{ matrix.python-version }}" = "3.13" ]; then + pytest --junitxml=pytest.xml --cov=network_wrangler --cov-report "xml:coverage.xml" --benchmark-save=benchmark --benchmark-json=benchmark.json + else + pytest --benchmark-save=benchmark --benchmark-json=benchmark.json + fi + + - name: Verify 
coverage files exist + if: matrix.python-version == '3.13' + shell: bash + run: | + echo "=== Checking for coverage files ===" + if [ -f "coverage.xml" ]; then + echo "✓ coverage.xml exists ($(wc -c < coverage.xml) bytes)" + head -20 coverage.xml + else + echo "✗ coverage.xml NOT FOUND" + echo "Current directory: $(pwd)" + echo "Files in current directory:" + ls -lah | head -20 + exit 1 + fi + + if [ -f "pytest.xml" ]; then + echo "✓ pytest.xml exists ($(wc -c < pytest.xml) bytes)" + else + echo "⚠ pytest.xml NOT FOUND" + fi + + - name: Upload coverage artifacts + if: matrix.python-version == '3.13' + uses: actions/upload-artifact@v4 + with: + name: coverage-py3.13 + path: | + coverage.xml + pytest.xml + if-no-files-found: error + + - name: Upload benchmark artifacts + if: matrix.python-version == '3.13' + uses: actions/upload-artifact@v4 + with: + name: benchmark-py3.13 + path: | + benchmark.json + .benchmarks/ + if-no-files-found: warn + + benchmark: + name: Compare and store benchmarks + needs: test + runs-on: ubuntu-latest + permissions: + contents: write # Required to commit benchmark.json to branch + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Need full history to compare to previous commit + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Download benchmark artifacts + uses: actions/download-artifact@v4 + with: + name: benchmark-py3.13 + path: ./ + continue-on-error: true + + - name: Check if benchmark.json exists + shell: bash + run: | + if [ ! -f "benchmark.json" ]; then + echo "WARNING: benchmark.json not found. Skipping benchmark comparison." + echo "This can happen if:" + echo " 1. Tests didn't run benchmarks" + echo " 2. Benchmark files weren't created" + echo " 3. 
Upload step was skipped" + exit 0 + fi + echo "✓ benchmark.json found" + + - name: Compare and commit benchmarks + if: hashFiles('benchmark.json') != '' + uses: ./.github/actions/compare-benchmarks + with: + comparison-type: push + benchmark-json-path: benchmark.json + github-token: ${{ secrets.GITHUB_TOKEN }} + python-version: "3.13" + + docs: + name: Build and deploy docs + needs: lint + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Get branch name + id: branch + uses: ./.github/actions/get-branch-name + + - name: Build and deploy docs + uses: ./.github/actions/build-docs + with: + python-version: "3.13" + branch-name: ${{ steps.branch.outputs.branch-name }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml deleted file mode 100644 index ce95f2b4..00000000 --- a/.github/workflows/release.yml +++ /dev/null @@ -1,74 +0,0 @@ -# This workflow will build and upload a Python Package to PyPI -# https://github.com/marketplace/actions/pypi-publish -name: Build + Publish Python Package 📦 to PyPI -on: - release: - types: [created] - workflow_dispatch: # Manual trigger -jobs: - build: - name: Build distribution 📦 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.x" - - name: Install pypa/build - run: >- - python3 -m - pip install - build - --user - - name: Build a binary wheel and a source tarball - run: python3 -m build - - name: Store the distribution packages - uses: actions/upload-artifact@v4 - with: - name: python-package-distributions - path: dist/ - publish-to-pypi: - name: >- - Publish Python 🐍 distribution 📦 to PyPI - needs: - - build - runs-on: ubuntu-latest - environment: - name: pypi - url: https://pypi.org/p/network-wrangler - permissions: - id-token: write # IMPORTANT: mandatory for trusted publishing - steps: - - name: Download all the dists - uses: 
actions/download-artifact@v4 - with: - name: python-package-distributions - path: dist/ - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - publish-docs: - name: Publish documentation 📚 to GitHub Pages - needs: - - build - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Configure Git user - run: | - git config --local user.email "github-actions[bot]@users.noreply.github.com" - git config --local user.name "github-actions[bot]" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -e .[docs] - - name: Build docs - run: | - mike deploy --push --update-aliases latest ${{ github.ref_name }} diff --git a/.gitignore b/.gitignore index c8756716..d6b78599 100644 --- a/.gitignore +++ b/.gitignore @@ -60,6 +60,7 @@ pip-log.txt pip-delete-this-directory.txt # Virtual Envs +.*venv* .env .venv venv/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index de880c74..3d8fea52 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. - rev: v0.6.8 + # Ruff version - keep in sync with GitHub Actions + rev: v0.11.13 hooks: # Run the linter. - id: ruff diff --git a/CHANGELOG.md b/CHANGELOG.md index bcda9c9d..07da6132 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ Notable changes and version history. | Version | Date | Comment | |---------|-------|-------| +| [v1.0-beta.4](https://github.com/network-wrangler/network_wrangler/releases/tag/v1.0-beta.4) | 2026-01-07 | Hotfix: Fix pandera/pandas compatibility issues, Python 3.9 TypeGuard compatibility, type guard bug fixes, mkdocs build compatibility, and remove unused noqa directives. 
| | [v1.0-beta-2](https://github.com/wsp-sag/network_wrangler/releases/tag/v1.0-beta-1) | 20204-10-15 | Bug fixes in scenario loading, projectcard API and compatibility of transit net with roadway deletions. Some additional performance improvements. | | [v1.0-beta-1](https://github.com/wsp-sag/network_wrangler/releases/tag/v1.0-beta-1) | 20204-10-9 | Feature-complete for 1.0 | | [v1.0-alpha-2](https://github.com/wsp-sag/network_wrangler/releases/tag/v1.0-alpha.2) | 2024-10-8 | Testing for Met Council | diff --git a/README.md b/README.md index 7c4a7781..11367cd6 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Network Wrangler is a Python library for managing travel model network scenarios Network Wrangler should be operating system agonistic and has been tested on Ubuntu and Mac OS. -Network Wrangler does require Python 3.7+. If you have a different version of Python installed (e.g. from ArcGIS), `conda` or a similar virtual environment manager can care of installing it for you in the installation instructions below. +Network Wrangler does require Python 3.10+. If you have a different version of Python installed (e.g. from ArcGIS), `conda` or a similar virtual environment manager can take care of installing it for you in the installation instructions below. ## Installation diff --git a/docs/index.md b/docs/index.md index d96ca3f1..b2549946 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,7 @@ Network Wrangler is a Python library for managing travel model network scenarios Network Wrangler should be operating system agonistic and has been tested on Ubuntu and Mac OS. -Network Wrangler does require Python 3.9+. If you have a different version of Python installed (e.g. from ArcGIS), `conda` or a similar virtual environment manager can care of installing it for you in the installation instructions below. +Network Wrangler does require Python 3.10+. If you have a different version of Python installed (e.g. 
from ArcGIS), `conda` or a similar virtual environment manager can take care of installing it for you in the installation instructions below. !!! tip "installing conda" diff --git a/docs/networks.md b/docs/networks.md index 60634a7f..a4d7f78c 100644 --- a/docs/networks.md +++ b/docs/networks.md @@ -98,7 +98,7 @@ A valid `geojson`, `shp`, or `parquet` file with `LineString` geometry features ::: network_wrangler.models.gtfs.tables options: show_bases: false - members: None + members: [] Transit Networks must use the the [GTFS](https://www.gtfs.org) Schedule format with the following additional constraints: @@ -112,7 +112,7 @@ Transit Networks must use the the [GTFS](https://www.gtfs.org) Schedule format w options: heading_level: 3 show_bases: false - members: None + members: [] handlers: python: options: @@ -124,7 +124,7 @@ Transit Networks must use the the [GTFS](https://www.gtfs.org) Schedule format w options: heading_level: 3 show_bases: false - members: None + members: [] handlers: python: options: @@ -134,7 +134,7 @@ Transit Networks must use the the [GTFS](https://www.gtfs.org) Schedule format w ::: network_wrangler.models.gtfs.tables.WranglerTripsTable options: - members: None + members: [] heading_level: 3 show_bases: false handlers: @@ -146,7 +146,7 @@ Transit Networks must use the the [GTFS](https://www.gtfs.org) Schedule format w ::: network_wrangler.models.gtfs.tables.WranglerStopTimesTable options: - members: None + members: [] heading_level: 3 show_bases: false handlers: @@ -158,7 +158,7 @@ Transit Networks must use the the [GTFS](https://www.gtfs.org) Schedule format w ::: network_wrangler.models.gtfs.tables.WranglerShapesTable options: - members: None + members: [] heading_level: 3 show_bases: false handlers: @@ -170,7 +170,7 @@ Transit Networks must use the the [GTFS](https://www.gtfs.org) Schedule format w ::: network_wrangler.models.gtfs.tables.WranglerFrequenciesTable options: - members: None + members: [] heading_level: 3 show_bases: false handlers: 
@@ -184,7 +184,7 @@ Transit Networks must use the the [GTFS](https://www.gtfs.org) Schedule format w options: heading_level: 3 show_bases: false - members: None + members: [] handlers: python: options: diff --git a/mkdocs.yml b/mkdocs.yml index 3dfc5510..79368724 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,5 +1,5 @@ site_name: "Network Wrangler" -site_url: https://wsp-sag.github.io/network_wrangler +site_url: https://network-wrangler.github.io/network_wrangler repo_url: https://github.com/wsp-sag/network_wrangler theme: @@ -39,8 +39,6 @@ plugins: handlers: python: paths: [.] - selection: - new_path_syntax: true rendering: show_root_heading: true show_source: true diff --git a/network_wrangler/__init__.py b/network_wrangler/__init__.py index dd550f63..eaab64d2 100644 --- a/network_wrangler/__init__.py +++ b/network_wrangler/__init__.py @@ -1,6 +1,6 @@ """Network Wrangler Package.""" -__version__ = "1.0-beta.3" +__version__ = "1.0-beta.4" import warnings diff --git a/network_wrangler/bin/build_basic_osm_roadnet.py b/network_wrangler/bin/build_basic_osm_roadnet.py index d3066572..fd1c2c3d 100755 --- a/network_wrangler/bin/build_basic_osm_roadnet.py +++ b/network_wrangler/bin/build_basic_osm_roadnet.py @@ -9,9 +9,9 @@ Arguments: place_name (str): Name of the place to build the road network for. - --type (Optional[str]): Type of network to build Defaults to `drive`. - --path (Optional[str]): Path to write the network. Defaults to current working directory. - --file_format (Optional[str]): File format for writing the network. Defaults to `geojson`. + --type (str | None): Type of network to build Defaults to `drive`. + --path (str | None): Path to write the network. Defaults to current working directory. + --file_format (str | None): File format for writing the network. Defaults to `geojson`. 
Example: ```bash diff --git a/network_wrangler/configs/__init__.py b/network_wrangler/configs/__init__.py index 9dcc7a0e..079b3467 100644 --- a/network_wrangler/configs/__init__.py +++ b/network_wrangler/configs/__init__.py @@ -1,17 +1,16 @@ """Configuration module for network_wrangler.""" from pathlib import Path -from typing import Optional, Union from ..logger import WranglerLogger from .scenario import ScenarioConfig from .utils import _config_data_from_files from .wrangler import DefaultConfig, WranglerConfig -ConfigInputTypes = Union[dict, Path, list[Path], WranglerConfig] +ConfigInputTypes = dict | Path | list[Path] | WranglerConfig -def load_wrangler_config(data: Optional[ConfigInputTypes] = None) -> WranglerConfig: +def load_wrangler_config(data: ConfigInputTypes | None = None) -> WranglerConfig: """Load the WranglerConfiguration.""" if isinstance(data, WranglerConfig): return data @@ -29,7 +28,7 @@ def load_wrangler_config(data: Optional[ConfigInputTypes] = None) -> WranglerCon def load_scenario_config( - data: Optional[Union[ScenarioConfig, Path, list[Path], dict]] = None, + data: ScenarioConfig | Path | list[Path] | dict | None = None, ) -> ScenarioConfig: """Load the WranglerConfiguration.""" if isinstance(data, ScenarioConfig): diff --git a/network_wrangler/configs/scenario.py b/network_wrangler/configs/scenario.py index 2755e006..f93794f2 100644 --- a/network_wrangler/configs/scenario.py +++ b/network_wrangler/configs/scenario.py @@ -79,16 +79,15 @@ from datetime import datetime from pathlib import Path -from typing import Optional, Union from projectcard.io import _resolve_rel_paths from ..models._base.types import RoadwayFileTypes, TransitFileTypes from .utils import ConfigItem -from .wrangler import DefaultConfig, WranglerConfig +from .wrangler import DefaultConfig -ProjectCardFilepath = Union[Path, str] -ProjectCardFilepaths = Union[Path, list[Path], str, list[str]] +ProjectCardFilepath = Path | str +ProjectCardFilepaths = Path | list[Path] | str 
| list[str] DEFAULT_SCENARIO_NAME: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -153,8 +152,8 @@ def __init__( dir: Path = DEFAULT_ROADWAY_IN_DIR, file_format: RoadwayFileTypes = DEFAULT_ROADWAY_IN_FORMAT, read_in_shapes: bool = DEFAULT_ROADWAY_SHAPE_READ, - boundary_geocode: Optional[str] = None, - boundary_file: Optional[Path] = None, + boundary_geocode: str | None = None, + boundary_file: Path | None = None, ): """Constructor for RoadwayNetworkInputConfig.""" if dir is not None and not Path(dir).is_absolute(): @@ -184,7 +183,7 @@ def __init__( out_dir: Path = DEFAULT_ROADWAY_OUT_DIR, base_path: Path = DEFAULT_BASE_DIR, convert_complex_link_properties_to_single_field: bool = False, - prefix: Optional[str] = None, + prefix: str | None = None, file_format: RoadwayFileTypes = DEFAULT_ROADWAY_OUT_FORMAT, true_shape: bool = False, write: bool = DEFAULT_ROADWAY_WRITE, @@ -242,7 +241,7 @@ def __init__( self, base_path: Path = DEFAULT_BASE_DIR, out_dir: Path = DEFAULT_TRANSIT_OUT_DIR, - prefix: Optional[str] = None, + prefix: str | None = None, file_format: TransitFileTypes = DEFAULT_TRANSIT_OUT_FORMAT, write: bool = DEFAULT_TRANSIT_WRITE, ): @@ -291,21 +290,21 @@ class ScenarioInputConfig(ConfigItem): def __init__( self, base_path: Path = DEFAULT_BASE_DIR, - roadway: Optional[dict] = None, - transit: Optional[dict] = None, - applied_projects: Optional[list[str]] = None, - conflicts: Optional[dict] = None, + roadway: dict | None = None, + transit: dict | None = None, + applied_projects: list[str] | None = None, + conflicts: dict | None = None, ): """Constructor for ScenarioInputConfig.""" if roadway is not None: - self.roadway: Optional[RoadwayNetworkInputConfig] = RoadwayNetworkInputConfig( + self.roadway: RoadwayNetworkInputConfig | None = RoadwayNetworkInputConfig( **roadway, base_path=base_path ) else: self.roadway = None if transit is not None: - self.transit: Optional[TransitNetworkInputConfig] = TransitNetworkInputConfig( + self.transit: 
TransitNetworkInputConfig | None = TransitNetworkInputConfig( **transit, base_path=base_path ) else: @@ -329,9 +328,9 @@ def __init__( self, path: Path = DEFAULT_OUTPUT_DIR, base_path: Path = DEFAULT_BASE_DIR, - roadway: Optional[dict] = None, - transit: Optional[dict] = None, - project_cards: Optional[dict] = None, + roadway: dict | None = None, + transit: dict | None = None, + project_cards: dict | None = None, overwrite: bool = True, ): """Constructor for ScenarioOutputConfig.""" @@ -346,7 +345,7 @@ def __init__( self.transit = TransitNetworkOutputConfig(**transit, base_path=self.path) if project_cards is not None: - self.project_cards: Optional[ProjectCardOutputConfig] = ProjectCardOutputConfig( + self.project_cards: ProjectCardOutputConfig | None = ProjectCardOutputConfig( **project_cards, base_path=self.path ) else: diff --git a/network_wrangler/configs/utils.py b/network_wrangler/configs/utils.py index 8a1b0c50..9606c7c7 100644 --- a/network_wrangler/configs/utils.py +++ b/network_wrangler/configs/utils.py @@ -1,7 +1,6 @@ """Configuration utilities.""" from pathlib import Path -from typing import Optional, Union from pydantic import ValidationError @@ -21,7 +20,7 @@ class ConfigItem: Do not use "get" "to_dict", or "items" for key names. 
""" - base_path: Optional[Path] = None + base_path: Path | None = None def __getitem__(self, key): """Return the value for key if key is in the dictionary, else default.""" @@ -45,7 +44,7 @@ def get(self, key, default=None): """Return the value for key if key is in the dictionary, else default.""" return self.__dict__.get(key, default) - def update(self, data: Union[Path, list[Path], dict]): + def update(self, data: Path | list[Path] | dict): """Update the configuration with a dictionary of new values.""" if not isinstance(data, dict): WranglerLogger.info(f"Updating configuration with {data}.") @@ -65,7 +64,7 @@ def resolve_paths(self, base_path): setattr(self, key, str(resolved_path)) -def find_configs_in_dir(dir: Union[Path, list[Path]], config_type) -> list[Path]: +def find_configs_in_dir(dir: Path | list[Path], config_type) -> list[Path]: """Find configuration files in the directory that match `*config`.""" config_files: list[Path] = [] if isinstance(dir, list): @@ -88,7 +87,7 @@ def find_configs_in_dir(dir: Union[Path, list[Path]], config_type) -> list[Path] return [] -def _config_data_from_files(path: Optional[Union[Path, list[Path]]] = None) -> Union[None, dict]: +def _config_data_from_files(path: Path | list[Path] | None = None) -> None | dict: """Load and combine configuration data from file(s). Args: diff --git a/network_wrangler/configs/wrangler.py b/network_wrangler/configs/wrangler.py index 5e4f9a51..bdf4df8d 100644 --- a/network_wrangler/configs/wrangler.py +++ b/network_wrangler/configs/wrangler.py @@ -218,10 +218,10 @@ class WranglerConfig(ConfigItem): EDITS: Parameters governing how edits are handled. 
""" - IDS: IdGenerationConfig = IdGenerationConfig() - MODEL_ROADWAY: ModelRoadwayConfig = ModelRoadwayConfig() - CPU: CpuConfig = CpuConfig() - EDITS: EditsConfig = EditsConfig() + IDS: IdGenerationConfig = Field(default_factory=IdGenerationConfig) + MODEL_ROADWAY: ModelRoadwayConfig = Field(default_factory=ModelRoadwayConfig) + CPU: CpuConfig = Field(default_factory=CpuConfig) + EDITS: EditsConfig = Field(default_factory=EditsConfig) DefaultConfig = WranglerConfig() diff --git a/network_wrangler/logger.py b/network_wrangler/logger.py index b5e29170..6917d28b 100644 --- a/network_wrangler/logger.py +++ b/network_wrangler/logger.py @@ -1,18 +1,16 @@ """Logging utilities for Network Wrangler.""" import logging -import os import sys from datetime import datetime from pathlib import Path -from typing import Optional WranglerLogger = logging.getLogger("WranglerLogger") def setup_logging( - info_log_filename: Optional[Path] = None, - debug_log_filename: Optional[Path] = None, + info_log_filename: Path | None = None, + debug_log_filename: Path | None = None, std_out_level: str = "info", ): """Sets up the WranglerLogger w.r.t. the debug file location and if logging to console. diff --git a/network_wrangler/models/_base/db.py b/network_wrangler/models/_base/db.py index 34e6d518..e2280554 100644 --- a/network_wrangler/models/_base/db.py +++ b/network_wrangler/models/_base/db.py @@ -1,7 +1,8 @@ import copy import hashlib from collections import defaultdict -from typing import Callable, ClassVar, Optional +from collections.abc import Callable +from typing import ClassVar import pandas as pd from pandera import DataFrameModel @@ -83,6 +84,8 @@ class DBModelMixin: optional_table_names: list of optional table names that will be added to `table_names` iff they are found. hash: creates a hash of tables found in `table_names` to track if they change. + modification_version: counter that increments when tables are modified. 
Used for + efficient change detection without computing expensive hashes. tables: dataframes corresponding to each table_name in `table_names` tables_dict: mapping of `:` dataframe _table_models: mapping of `:` to use for validation when @@ -109,6 +112,27 @@ class DBModelMixin: # mapping of : to use iff df validation fails. _converters: ClassVar[dict[str, Callable]] = {} + # Instance attribute for tracking modifications (initialized in __setattr__) + _modification_version: int = 0 + + def _mark_modified(self) -> None: + """Mark the database as modified by incrementing the modification version. + + This is called automatically when tables are modified via __setattr__. + """ + # Use object.__setattr__ to avoid recursion + current = getattr(self, "_modification_version", 0) + object.__setattr__(self, "_modification_version", current + 1) + + @property + def modification_version(self) -> int: + """Return the current modification version. + + This counter increments each time a table is modified and can be used + for efficient change detection without computing expensive hashes. + """ + return getattr(self, "_modification_version", 0) + def __setattr__(self, key, value): """Override the default setattr behavior to handle DataFrame validation. 
@@ -126,6 +150,9 @@ def __setattr__(self, key, value): WranglerLogger.debug(f"Validating + coercing value to {key}") df = self.validate_coerce_table(key, value) super().__setattr__(key, df) + # Mark as modified when a table is updated + if key in self.table_names or key in self.optional_table_names: + self._mark_modified() else: super().__setattr__(key, value) @@ -203,7 +230,7 @@ def fields_as_fks(cls) -> DbForeignKeyUsage: return {k: dict(v) for k, v in pks_as_fks.items()} def check_referenced_fk( - self, pk_table_name: str, pk_field: str, pk_table: Optional[pd.DataFrame] = None + self, pk_table_name: str, pk_field: str, pk_table: pd.DataFrame | None = None ) -> bool: """True if table.field has the values referenced in any table referencing fields as fk. @@ -264,7 +291,7 @@ def check_referenced_fk( ) return all_valid - def check_referenced_fks(self, table_name: str, table: Optional[pd.DataFrame] = None) -> bool: + def check_referenced_fks(self, table_name: str, table: pd.DataFrame | None = None) -> bool: """True if this table has the values referenced in any table referencing fields as fk. For example. If routes.route_id is referenced in trips table, we need to check that @@ -281,7 +308,7 @@ def check_referenced_fks(self, table_name: str, table: Optional[pd.DataFrame] = return all_valid def check_table_fks( - self, table_name: str, table: Optional[pd.DataFrame] = None, raise_error: bool = True + self, table_name: str, table: pd.DataFrame | None = None, raise_error: bool = True ) -> bool: """Return True if the foreign key fields in table have valid references. @@ -385,7 +412,11 @@ def table_names_with_field(self, field: str) -> list[str]: @property def hash(self) -> str: - """A hash representing the contents of the tables in self.table_names.""" + """A hash representing the contents of the tables in self.table_names. + + Note: This is an expensive operation. For change detection, prefer using + modification_version which is much faster. 
+ """ _table_hashes = [self.get_table(t).df_hash() for t in self.table_names] _value = str.encode("-".join(_table_hashes)) diff --git a/network_wrangler/models/_base/geo.py b/network_wrangler/models/_base/geo.py index 55f167e0..4ba3486b 100644 --- a/network_wrangler/models/_base/geo.py +++ b/network_wrangler/models/_base/geo.py @@ -1,9 +1,6 @@ -from typing import Union - from pydantic import ( RootModel, confloat, - conlist, field_validator, ) diff --git a/network_wrangler/models/_base/records.py b/network_wrangler/models/_base/records.py index f53a968d..e1eb9775 100644 --- a/network_wrangler/models/_base/records.py +++ b/network_wrangler/models/_base/records.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Union +from typing import Any, ClassVar from pydantic import BaseModel, ConfigDict, model_validator @@ -44,9 +44,7 @@ class RecordModel(BaseModel): _examples: ClassVar[list[Any]] = [] @staticmethod - def _check_field_exists( - all_of_fields: Union[str, list[str]], fields_present: list[str] - ) -> bool: + def _check_field_exists(all_of_fields: str | list[str], fields_present: list[str]) -> bool: if isinstance(all_of_fields, list): return all(f in fields_present for f in all_of_fields) return all_of_fields in fields_present diff --git a/network_wrangler/models/_base/series.py b/network_wrangler/models/_base/series.py index 3ff18126..a370d93f 100644 --- a/network_wrangler/models/_base/series.py +++ b/network_wrangler/models/_base/series.py @@ -1,10 +1,14 @@ import pandera as pa +from pandera.engines import pandas_engine """ Time strings in HH:MM or HH:MM:SS format up to 48 hours. 
""" +# Use pandas_engine.NpString instead of pa.String to avoid StringDtype compatibility +# issues with numpy.issubdtype in newer pandas versions (2.2+) +# NpString uses object dtype which is compatible with numpy TimeStrSeriesSchema = pa.SeriesSchema( - pa.String, + pandas_engine.NpString(), pa.Check.str_matches(r"^(?:[0-9]|[0-3][0-9]|4[0-7]):[0-5]\d(?::[0-5]\d)?$|^24:00(?::00)?$"), coerce=True, name=None, # Name is set to None to ignore the Series name diff --git a/network_wrangler/models/_base/tables.py b/network_wrangler/models/_base/tables.py index 9ec8140a..743b8642 100644 --- a/network_wrangler/models/_base/tables.py +++ b/network_wrangler/models/_base/tables.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Any import pandas as pd from pandera.extensions import register_check_method diff --git a/network_wrangler/models/_base/types.py b/network_wrangler/models/_base/types.py index f1b57074..81fcf176 100644 --- a/network_wrangler/models/_base/types.py +++ b/network_wrangler/models/_base/types.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import time -from typing import Any, Literal, TypeVar, Union +from typing import Any, Literal, TypeVar import pandas as pd @@ -16,9 +16,9 @@ ForcedStr = Any # For simplicity, since BeforeValidator is not used here -OneOf = list[list[Union[str, list[str]]]] +OneOf = list[list[str | list[str]]] ConflictsWith = list[list[str]] -AnyOf = list[list[Union[str, list[str]]]] +AnyOf = list[list[str | list[str]]] Latitude = float Longitude = float @@ -43,7 +43,7 @@ def validate_timespan_string(value: Any) -> list[str]: if not isinstance(item, str): msg = "TimespanString elements must be strings" raise ValueError(msg) - import re # noqa: PLC0415 + import re if not re.match(r"^(\d+):([0-5]\d)(:[0-5]\d)?$", item): msg = f"Invalid time format: {item}" @@ -52,4 +52,4 @@ def validate_timespan_string(value: Any) -> list[str]: TimespanString = list[str] -TimeType = Union[time, str, int] +TimeType = time | str | 
int diff --git a/network_wrangler/models/gtfs/converters.py b/network_wrangler/models/gtfs/converters.py index 633c2516..bf1af1e2 100644 --- a/network_wrangler/models/gtfs/converters.py +++ b/network_wrangler/models/gtfs/converters.py @@ -56,9 +56,7 @@ def convert_stops_to_wrangler_stops(stops_df: pd.DataFrame) -> pd.DataFrame: # if stop_id is an int, convert to string if stops_df["stop_id"].dtype == "int64": stops_df["stop_id"] = stops_df["stop_id"].astype(str) - gtfs_stop_id = ( - stops_df.groupby("model_node_id").stop_id.apply(lambda x: ",".join(x)).reset_index() - ) + gtfs_stop_id = stops_df.groupby("model_node_id").stop_id.agg(",".join).reset_index() wr_stops_df["gtfs_stop_id"] = gtfs_stop_id["stop_id"] wr_stops_df = wr_stops_df.rename(columns={"model_node_id": "stop_id"}) return wr_stops_df diff --git a/network_wrangler/models/gtfs/table_types.py b/network_wrangler/models/gtfs/table_types.py index 217d4c5f..806e53af 100644 --- a/network_wrangler/models/gtfs/table_types.py +++ b/network_wrangler/models/gtfs/table_types.py @@ -2,10 +2,8 @@ import re from collections.abc import Iterable -from typing import Optional, Union import pandas as pd -import pandera as pa from pandera.dtypes import DataType from pandera.engines import pandas_engine @@ -21,7 +19,7 @@ def check( self, pandera_dtype: DataType, data_container: pd.Series, - ) -> Union[bool, Iterable[bool]]: + ) -> bool | Iterable[bool]: """Check if the data is a valid HTTP URL.""" correct_type = super().check(pandera_dtype) if not correct_type: diff --git a/network_wrangler/models/gtfs/tables.py b/network_wrangler/models/gtfs/tables.py index 1d043af6..cb2e1d86 100644 --- a/network_wrangler/models/gtfs/tables.py +++ b/network_wrangler/models/gtfs/tables.py @@ -23,7 +23,7 @@ ``` """ -from typing import ClassVar, Optional +from typing import ClassVar import pandas as pd import pandera as pa @@ -31,7 +31,6 @@ from pandera import DataFrameModel, Field from pandera.typing import Category, Series -from ...logger 
import WranglerLogger from ...params import DEFAULT_TIMESPAN from ...utils.time import str_to_time, str_to_time_series from .._base.db import TableForeignKeys, TablePrimaryKeys @@ -90,22 +89,22 @@ class StopsTable(DataFrameModel): stop_id (str): The stop_id. Primary key. Required to be unique. stop_lat (float): The stop latitude. stop_lon (float): The stop longitude. - wheelchair_boarding (Optional[int]): The wheelchair boarding. - stop_code (Optional[str]): The stop code. - stop_name (Optional[str]): The stop name. - tts_stop_name (Optional[str]): The text-to-speech stop name. - stop_desc (Optional[str]): The stop description. - zone_id (Optional[str]): The zone id. - stop_url (Optional[str]): The stop URL. - location_type (Optional[LocationType]): The location type. Values can be: + wheelchair_boarding (int | None): The wheelchair boarding. + stop_code (str | None): The stop code. + stop_name (str | None): The stop name. + tts_stop_name (str | None): The text-to-speech stop name. + stop_desc (str | None): The stop description. + zone_id (str | None): The zone id. + stop_url (str | None): The stop URL. + location_type (LocationType | None): The location type. Values can be: - 0: stop platform - 1: station - 2: entrance/exit - 3: generic node - 4: boarding area Default of blank assumes a stop platform. - parent_station (Optional[str]): The `stop_id` of the parent station. - stop_timezone (Optional[str]): The stop timezone. + parent_station (str | None): The `stop_id` of the parent station. + stop_timezone (str | None): The stop timezone. 
""" stop_id: Series[str] = Field(coerce=True, nullable=False, unique=True) @@ -113,23 +112,23 @@ class StopsTable(DataFrameModel): stop_lon: Series[float] = Field(coerce=True, nullable=False, ge=-180, le=180) # Optional Fields - wheelchair_boarding: Optional[Series[Category]] = Field( + wheelchair_boarding: Series[Category] | None = Field( dtype_kwargs={"categories": WheelchairAccessible}, coerce=True, default=0 ) - stop_code: Optional[Series[str]] = Field(nullable=True, coerce=True) - stop_name: Optional[Series[str]] = Field(nullable=True, coerce=True) - tts_stop_name: Optional[Series[str]] = Field(nullable=True, coerce=True) - stop_desc: Optional[Series[str]] = Field(nullable=True, coerce=True) - zone_id: Optional[Series[str]] = Field(nullable=True, coerce=True) - stop_url: Optional[Series[str]] = Field(nullable=True, coerce=True) - location_type: Optional[Series[Category]] = Field( + stop_code: Series[str] | None = Field(nullable=True, coerce=True) + stop_name: Series[str] | None = Field(nullable=True, coerce=True) + tts_stop_name: Series[str] | None = Field(nullable=True, coerce=True) + stop_desc: Series[str] | None = Field(nullable=True, coerce=True) + zone_id: Series[str] | None = Field(nullable=True, coerce=True) + stop_url: Series[str] | None = Field(nullable=True, coerce=True) + location_type: Series[Category] | None = Field( dtype_kwargs={"categories": LocationType}, nullable=True, coerce=True, default=0, ) - parent_station: Optional[Series[str]] = Field(nullable=True, coerce=True) - stop_timezone: Optional[Series[str]] = Field(nullable=True, coerce=True) + parent_station: Series[str] | None = Field(nullable=True, coerce=True) + stop_timezone: Series[str] | None = Field(nullable=True, coerce=True) class Config: """Config for the StopsTable data model.""" @@ -149,23 +148,23 @@ class WranglerStopsTable(StopsTable): stop_id (int): The stop_id. Primary key. Required to be unique. 
**Wrangler assumes that this is a reference to a roadway node and as such must be an integer** stop_lat (float): The stop latitude. stop_lon (float): The stop longitude. - wheelchair_boarding (Optional[int]): The wheelchair boarding. - stop_code (Optional[str]): The stop code. - stop_name (Optional[str]): The stop name. - tts_stop_name (Optional[str]): The text-to-speech stop name. - stop_desc (Optional[str]): The stop description. - zone_id (Optional[str]): The zone id. - stop_url (Optional[str]): The stop URL. - location_type (Optional[LocationType]): The location type. Values can be: + wheelchair_boarding (int | None): The wheelchair boarding. + stop_code (str | None): The stop code. + stop_name (str | None): The stop name. + tts_stop_name (str | None): The text-to-speech stop name. + stop_desc (str | None): The stop description. + zone_id (str | None): The zone id. + stop_url (str | None): The stop URL. + location_type (LocationType | None): The location type. Values can be: - 0: stop platform - 1: station - 2: entrance/exit - 3: generic node - 4: boarding area Default of blank assumes a stop platform. - parent_station (Optional[int]): The `stop_id` of the parent station. **Since stop_id is an integer in Wrangler, this field is also an integer** - stop_timezone (Optional[str]): The stop timezone. - stop_id_GTFS (Optional[str]): The stop_id from the GTFS data. + parent_station (int | None): The `stop_id` of the parent station. **Since stop_id is an integer in Wrangler, this field is also an integer** + stop_timezone (str | None): The stop timezone. + stop_id_GTFS (str | None): The stop_id from the GTFS data. projects (str): A comma-separated string value for projects that have been applied to this stop. """ @@ -189,18 +188,18 @@ class RoutesTable(DataFrameModel): Attributes: route_id (str): The route_id. Primary key. Required to be unique. - route_short_name (Optional[str]): The route short name. - route_long_name (Optional[str]): The route long name. 
+ route_short_name (str | None): The route short name. + route_long_name (str | None): The route long name. route_type (RouteType): The route type. Required. Values can be: - 0: Tram, Streetcar, Light rail - 1: Subway, Metro - 2: Rail - 3: Bus - agency_id (Optional[str]): The agency_id. Foreign key to agency_id in the agencies table. - route_desc (Optional[str]): The route description. - route_url (Optional[str]): The route URL. - route_color (Optional[str]): The route color. - route_text_color (Optional[str]): The route text color. + agency_id (str | None): The agency_id. Foreign key to agency_id in the agencies table. + route_desc (str | None): The route description. + route_url (str | None): The route URL. + route_color (str | None): The route color. + route_text_color (str | None): The route text color. """ route_id: Series[str] = Field(nullable=False, unique=True, coerce=True) @@ -211,11 +210,11 @@ class RoutesTable(DataFrameModel): ) # Optional Fields - agency_id: Optional[Series[str]] = Field(nullable=True, coerce=True) - route_desc: Optional[Series[str]] = Field(nullable=True, coerce=True) - route_url: Optional[Series[str]] = Field(nullable=True, coerce=True) - route_color: Optional[Series[str]] = Field(nullable=True, coerce=True) - route_text_color: Optional[Series[str]] = Field(nullable=True, coerce=True) + agency_id: Series[str] | None = Field(nullable=True, coerce=True) + route_desc: Series[str] | None = Field(nullable=True, coerce=True) + route_url: Series[str] | None = Field(nullable=True, coerce=True) + route_color: Series[str] | None = Field(nullable=True, coerce=True) + route_text_color: Series[str] | None = Field(nullable=True, coerce=True) class Config: """Config for the RoutesTable data model.""" @@ -236,7 +235,7 @@ class ShapesTable(DataFrameModel): shape_pt_lat (float): The shape point latitude. shape_pt_lon (float): The shape point longitude. shape_pt_sequence (int): The shape point sequence. 
- shape_dist_traveled (Optional[float]): The shape distance traveled. + shape_dist_traveled (float | None): The shape distance traveled. """ shape_id: Series[str] = Field(nullable=False, coerce=True) @@ -245,7 +244,7 @@ class ShapesTable(DataFrameModel): shape_pt_sequence: Series[int] = Field(coerce=True, nullable=False, ge=0) # Optional - shape_dist_traveled: Optional[Series[float]] = Field(coerce=True, nullable=True, ge=0) + shape_dist_traveled: Series[float] | None = Field(coerce=True, nullable=True, ge=0) class Config: """Config for the ShapesTable data model.""" @@ -267,7 +266,7 @@ class WranglerShapesTable(ShapesTable): shape_pt_lat (float): The shape point latitude. shape_pt_lon (float): The shape point longitude. shape_pt_sequence (int): The shape point sequence. - shape_dist_traveled (Optional[float]): The shape distance traveled. + shape_dist_traveled (float | None): The shape distance traveled. shape_model_node_id (int): The model_node_id of the shape point. Foreign key to the model_node_id in the nodes table. projects (str): A comma-separated string value for projects that have been applied to this shape. """ @@ -289,14 +288,14 @@ class TripsTable(DataFrameModel): - 1: Inbound service_id (str): The service id. route_id (str): The route id. Foreign key to `route_id` in the routes table. - trip_short_name (Optional[str]): The trip short name. - trip_headsign (Optional[str]): The trip headsign. - block_id (Optional[str]): The block id. - wheelchair_accessible (Optional[int]): The wheelchair accessible. Values can be: + trip_short_name (str | None): The trip short name. + trip_headsign (str | None): The trip headsign. + block_id (str | None): The block id. + wheelchair_accessible (int | None): The wheelchair accessible. Values can be: - 0: No information - 1: Allowed - 2: Not allowed - bikes_allowed (Optional[int]): The bikes allowed. Values can be: + bikes_allowed (int | None): The bikes allowed. 
Values can be: - 0: No information - 1: Allowed - 2: Not allowed @@ -311,13 +310,13 @@ class TripsTable(DataFrameModel): route_id: Series[str] = Field(nullable=False, coerce=True) # Optional Fields - trip_short_name: Optional[Series[str]] = Field(nullable=True, coerce=True) - trip_headsign: Optional[Series[str]] = Field(nullable=True, coerce=True) - block_id: Optional[Series[str]] = Field(nullable=True, coerce=True) - wheelchair_accessible: Optional[Series[Category]] = Field( + trip_short_name: Series[str] | None = Field(nullable=True, coerce=True) + trip_headsign: Series[str] | None = Field(nullable=True, coerce=True) + block_id: Series[str] | None = Field(nullable=True, coerce=True) + wheelchair_accessible: Series[Category] | None = Field( dtype_kwargs={"categories": WheelchairAccessible}, coerce=True, default=0 ) - bikes_allowed: Optional[Series[Category]] = Field( + bikes_allowed: Series[Category] | None = Field( dtype_kwargs={"categories": BikesAllowed}, coerce=True, default=0, @@ -345,14 +344,14 @@ class WranglerTripsTable(TripsTable): - 1: Inbound service_id (str): The service id. route_id (str): The route id. Foreign key to `route_id` in the routes table. - trip_short_name (Optional[str]): The trip short name. - trip_headsign (Optional[str]): The trip headsign. - block_id (Optional[str]): The block id. - wheelchair_accessible (Optional[int]): The wheelchair accessible. Values can be: + trip_short_name (str | None): The trip short name. + trip_headsign (str | None): The trip headsign. + block_id (str | None): The block id. + wheelchair_accessible (int | None): The wheelchair accessible. Values can be: - 0: No information - 1: Allowed - 2: Not allowed - bikes_allowed (Optional[int]): The bikes allowed. Values can be: + bikes_allowed (int | None): The bikes allowed. 
Values can be: - 0: No information - 1: Allowed - 2: Not allowed @@ -475,8 +474,8 @@ class StopTimesTable(DataFrameModel): - 3: Must coordinate with driver to arrange drop off arrival_time (TimeString): The arrival time in HH:MM:SS format. departure_time (TimeString): The departure time in HH:MM:SS format. - shape_dist_traveled (Optional[float]): The shape distance traveled. - timepoint (Optional[TimepointType]): The timepoint type. Values can be: + shape_dist_traveled (float | None): The shape distance traveled. + timepoint (TimepointType | None): The timepoint type. Values can be: - 0: The stop is not a timepoint - 1: The stop is a timepoint """ @@ -498,8 +497,8 @@ class StopTimesTable(DataFrameModel): departure_time: Series[pa.Timestamp] = Field(nullable=True, default=pd.NaT, coerce=True) # Optional - shape_dist_traveled: Optional[Series[float]] = Field(coerce=True, nullable=True, ge=0) - timepoint: Optional[Series[Category]] = Field( + shape_dist_traveled: Series[float] | None = Field(coerce=True, nullable=True, ge=0) + timepoint: Series[Category] | None = Field( dtype_kwargs={"categories": TimepointType}, coerce=True, default=0 ) @@ -549,8 +548,8 @@ class WranglerStopTimesTable(StopTimesTable): - 1: No drop off available - 2: Must phone agency to arrange drop off - 3: Must coordinate with driver to arrange drop off - shape_dist_traveled (Optional[float]): The shape distance traveled. - timepoint (Optional[TimepointType]): The timepoint type. Values can be: + shape_dist_traveled (float | None): The shape distance traveled. + timepoint (TimepointType | None): The timepoint type. Values can be: - 0: The stop is not a timepoint - 1: The stop is a timepoint projects (str): A comma-separated string value for projects that have been applied to this stop. 
diff --git a/network_wrangler/models/gtfs/types.py b/network_wrangler/models/gtfs/types.py index 5bda7ce1..31f4fbc9 100644 --- a/network_wrangler/models/gtfs/types.py +++ b/network_wrangler/models/gtfs/types.py @@ -1,11 +1,6 @@ """Field types for GTFS data.""" from enum import IntEnum -from typing import Annotated - -from pydantic import Field, HttpUrl - -from .._base.types import TimeString class BikesAllowed(IntEnum): diff --git a/network_wrangler/models/projects/roadway_changes.py b/network_wrangler/models/projects/roadway_changes.py index d13d4445..c8a1ee46 100644 --- a/network_wrangler/models/projects/roadway_changes.py +++ b/network_wrangler/models/projects/roadway_changes.py @@ -4,7 +4,7 @@ import itertools from datetime import datetime -from typing import Any, ClassVar, Literal, Optional, Union +from typing import Any, ClassVar, Literal from pydantic import ( BaseModel, @@ -13,7 +13,6 @@ ValidationError, field_validator, model_validator, - validate_call, ) from ...errors import ScopeConflictError @@ -36,12 +35,12 @@ class IndivScopedPropertySetItem(BaseModel): model_config = ConfigDict(extra="forbid", exclude_none=True) - category: Optional[Union[str, int]] = DEFAULT_CATEGORY - timespan: Optional[TimespanString] = DEFAULT_TIMESPAN - set: Optional[Any] = None - existing: Optional[Any] = None - overwrite_conflicts: Optional[bool] = False - change: Optional[Union[int, float]] = None + category: str | int | None = DEFAULT_CATEGORY + timespan: TimespanString | None = DEFAULT_TIMESPAN + set: Any | None = None + existing: Any | None = None + overwrite_conflicts: bool | None = False + change: int | float | None = None _examples = [ {"category": "hov3", "timespan": ["6:00", "9:00"], "set": 2.0}, {"category": "hov2", "set": 2.0}, @@ -96,14 +95,14 @@ class GroupedScopedPropertySetItem(BaseModel): model_config = ConfigDict(extra="forbid", exclude_none=True) - category: Optional[Union[str, int]] = None - timespan: Optional[TimespanString] = None - categories: 
Optional[list[Any]] = [] - timespans: Optional[list[TimespanString]] = [] - set: Optional[Any] = None - overwrite_conflicts: Optional[bool] = False - existing: Optional[Any] = None - change: Optional[Union[int, float]] = None + category: str | int | None = None + timespan: TimespanString | None = None + categories: list[Any] | None = [] + timespans: list[TimespanString] | None = [] + set: Any | None = None + overwrite_conflicts: bool | None = False + existing: Any | None = None + change: int | float | None = None _examples = [ {"category": "hov3", "timespan": ["6:00", "9:00"], "set": 2.0}, {"category": "hov2", "set": 2.0}, @@ -151,7 +150,7 @@ def validate_timespans(cls, v): def _grouped_to_indiv_list_of_scopedpropsetitem( - scoped_prop_set_list: list[Union[GroupedScopedPropertySetItem, IndivScopedPropertySetItem]], + scoped_prop_set_list: list[GroupedScopedPropertySetItem | IndivScopedPropertySetItem], ) -> list[IndivScopedPropertySetItem]: """Converts a list of ScopedPropertySetItem to a list of IndivScopedPropertySetItem. 
@@ -241,12 +240,12 @@ class RoadPropertyChange(RecordModel): model_config = ConfigDict(extra="forbid", exclude_none=True) - existing: Optional[Any] = None - change: Optional[Union[int, float]] = None - set: Optional[Any] = None - scoped: Optional[Union[None, ScopedPropertySetList]] = None - overwrite_scoped: Optional[Literal["conflicting", "all", "error"]] = None - existing_value_conflict: Optional[Literal["error", "warn", "skip"]] = None + existing: Any | None = None + change: int | float | None = None + set: Any | None = None + scoped: None | ScopedPropertySetList = None + overwrite_scoped: Literal["conflicting", "all", "error"] | None = None + existing_value_conflict: Literal["error", "warn", "skip"] | None = None require_one_of: ClassVar[OneOf] = [ ["change", "set"], @@ -289,14 +288,14 @@ class RoadwayDeletion(RecordModel): require_any_of: ClassVar[AnyOf] = [["links", "nodes"]] model_config = ConfigDict(extra="forbid") - links: Optional[SelectLinksDict] = None - nodes: Optional[SelectNodesDict] = None + links: SelectLinksDict | None = None + nodes: SelectNodesDict | None = None clean_shapes: bool = False clean_nodes: bool = False @field_validator("links") @classmethod - def set_to_all_modes(cls, links: Optional[SelectLinksDict] = None): + def set_to_all_modes(cls, links: SelectLinksDict | None = None): """Set the search mode to 'any' if not specified explicitly.""" if links is not None and links.modes == DEFAULT_SEARCH_MODES: links.modes = DEFAULT_DELETE_MODES diff --git a/network_wrangler/models/projects/roadway_selection.py b/network_wrangler/models/projects/roadway_selection.py index 9a87e3f6..8234bf66 100644 --- a/network_wrangler/models/projects/roadway_selection.py +++ b/network_wrangler/models/projects/roadway_selection.py @@ -2,11 +2,10 @@ from __future__ import annotations -from typing import Annotated, ClassVar, Optional +from typing import Annotated, ClassVar from pydantic import ConfigDict, Field -from ...logger import WranglerLogger from ...params 
import DEFAULT_SEARCH_MODES from .._base.records import RecordModel from .._base.types import AnyOf, ConflictsWith, OneOf @@ -22,8 +21,8 @@ class SelectNodeDict(RecordModel): require_one_of: ClassVar[OneOf] = [["osm_node_id", "model_node_id"]] model_config = ConfigDict(extra="allow") - osm_node_id: Optional[str] = None - model_node_id: Optional[int] = None + osm_node_id: str | None = None + model_node_id: int | None = None _examples: ClassVar[list[dict]] = [{"osm_node_id": "12345"}, {"model_node_id": 67890}] @@ -34,10 +33,10 @@ class SelectNodesDict(RecordModel): require_any_of: ClassVar[AnyOf] = [["osm_node_id", "model_node_id"]] model_config = ConfigDict(extra="forbid") - all: Optional[bool] = False - osm_node_id: Annotated[Optional[list[str]], Field(None, min_length=1)] - model_node_id: Annotated[Optional[list[int]], Field(min_length=1)] - ignore_missing: Optional[bool] = True + all: bool | None = False + osm_node_id: Annotated[list[str] | None, Field(None, min_length=1)] + model_node_id: Annotated[list[int] | None, Field(min_length=1)] + ignore_missing: bool | None = True _examples: ClassVar[list[dict]] = [ {"osm_node_id": ["12345", "67890"], "model_node_id": [12345, 67890]}, @@ -73,13 +72,13 @@ class SelectLinksDict(RecordModel): model_config = ConfigDict(extra="allow") - all: Optional[bool] = False - name: Annotated[Optional[list[str]], Field(None, min_length=1)] - ref: Annotated[Optional[list[str]], Field(None, min_length=1)] - osm_link_id: Annotated[Optional[list[str]], Field(None, min_length=1)] - model_link_id: Annotated[Optional[list[int]], Field(None, min_length=1)] + all: bool | None = False + name: Annotated[list[str] | None, Field(None, min_length=1)] + ref: Annotated[list[str] | None, Field(None, min_length=1)] + osm_link_id: Annotated[list[str] | None, Field(None, min_length=1)] + model_link_id: Annotated[list[int] | None, Field(None, min_length=1)] modes: list[str] = DEFAULT_SEARCH_MODES - ignore_missing: Optional[bool] = True + ignore_missing: 
bool | None = True _examples: ClassVar[list[dict]] = [ {"name": ["Main St"], "modes": ["drive"]}, @@ -97,10 +96,10 @@ class SelectFacility(RecordModel): ] model_config = ConfigDict(extra="forbid") - links: Optional[SelectLinksDict] = None - nodes: Optional[SelectNodesDict] = None - from_: Annotated[Optional[SelectNodeDict], Field(None, alias="from")] - to: Optional[SelectNodeDict] = None + links: SelectLinksDict | None = None + nodes: SelectNodesDict | None = None + from_: Annotated[SelectNodeDict | None, Field(None, alias="from")] + to: SelectNodeDict | None = None _examples: ClassVar[list[dict]] = [ { diff --git a/network_wrangler/models/projects/transit_selection.py b/network_wrangler/models/projects/transit_selection.py index 77f501e4..4352744c 100644 --- a/network_wrangler/models/projects/transit_selection.py +++ b/network_wrangler/models/projects/transit_selection.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Annotated, ClassVar, Literal, Optional +from typing import Annotated, ClassVar, Literal from pydantic import ConfigDict, Field, field_validator @@ -15,12 +15,12 @@ class SelectTripProperties(RecordModel): """Selection properties for transit trips.""" - trip_id: Annotated[Optional[list[ForcedStr]], Field(None, min_length=1)] - shape_id: Annotated[Optional[list[ForcedStr]], Field(None, min_length=1)] - direction_id: Annotated[Optional[int], Field(None)] - service_id: Annotated[Optional[list[ForcedStr]], Field(None, min_length=1)] - route_id: Annotated[Optional[list[ForcedStr]], Field(None, min_length=1)] - trip_short_name: Annotated[Optional[list[ForcedStr]], Field(None, min_length=1)] + trip_id: Annotated[list[ForcedStr] | None, Field(None, min_length=1)] + shape_id: Annotated[list[ForcedStr] | None, Field(None, min_length=1)] + direction_id: Annotated[int | None, Field(None)] + service_id: Annotated[list[ForcedStr] | None, Field(None, min_length=1)] + route_id: Annotated[list[ForcedStr] | None, Field(None, min_length=1)] + 
trip_short_name: Annotated[list[ForcedStr] | None, Field(None, min_length=1)] model_config = ConfigDict( extra="allow", @@ -33,8 +33,8 @@ class SelectTripProperties(RecordModel): class TransitABNodesModel(RecordModel): """Single transit link model.""" - A: Optional[int] = None # model_node_id - B: Optional[int] = None # model_node_id + A: int | None = None # model_node_id + B: int | None = None # model_node_id model_config = ConfigDict( extra="forbid", @@ -51,9 +51,9 @@ class SelectTransitLinks(RecordModel): ["ab_nodes", "model_link_id"], ] - model_link_id: Annotated[Optional[list[int]], Field(min_length=1)] = None - ab_nodes: Annotated[Optional[list[TransitABNodesModel]], Field(min_length=1)] = None - require: Optional[SelectionRequire] = "any" + model_link_id: Annotated[list[int] | None, Field(min_length=1)] = None + ab_nodes: Annotated[list[TransitABNodesModel] | None, Field(min_length=1)] = None + require: SelectionRequire | None = "any" model_config = ConfigDict( extra="forbid", @@ -83,9 +83,9 @@ class SelectTransitNodes(RecordModel): ] ] - # gtfs_stop_id: Annotated[Optional[List[ForcedStr]], Field(None, min_length=1)] TODO Not implemented + # gtfs_stop_id: Annotated[list[ForcedStr] | None, Field(None, min_length=1)] TODO Not implemented model_node_id: Annotated[list[int], Field(min_length=1)] - require: Optional[SelectionRequire] = "any" + require: SelectionRequire | None = "any" model_config = ConfigDict( extra="forbid", @@ -103,10 +103,10 @@ class SelectTransitNodes(RecordModel): class SelectRouteProperties(RecordModel): """Selection properties for transit routes.""" - route_short_name: Annotated[Optional[list[ForcedStr]], Field(None, min_length=1)] - route_long_name: Annotated[Optional[list[ForcedStr]], Field(None, min_length=1)] - agency_id: Annotated[Optional[list[ForcedStr]], Field(None, min_length=1)] - route_type: Annotated[Optional[list[int]], Field(None, min_length=1)] + route_short_name: Annotated[list[ForcedStr] | None, Field(None, min_length=1)] 
+ route_long_name: Annotated[list[ForcedStr] | None, Field(None, min_length=1)] + agency_id: Annotated[list[ForcedStr] | None, Field(None, min_length=1)] + route_type: Annotated[list[int] | None, Field(None, min_length=1)] model_config = ConfigDict( extra="allow", @@ -119,11 +119,11 @@ class SelectRouteProperties(RecordModel): class SelectTransitTrips(RecordModel): """Selection properties for transit trips.""" - trip_properties: Optional[SelectTripProperties] = None - route_properties: Optional[SelectRouteProperties] = None - timespans: Annotated[Optional[list[TimespanString]], Field(None, min_length=1)] - nodes: Optional[SelectTransitNodes] = None - links: Optional[SelectTransitLinks] = None + trip_properties: SelectTripProperties | None = None + route_properties: SelectRouteProperties | None = None + timespans: Annotated[list[TimespanString] | None, Field(None, min_length=1)] + nodes: SelectTransitNodes | None = None + links: SelectTransitLinks | None = None model_config = ConfigDict( extra="forbid", diff --git a/network_wrangler/models/roadway/converters.py b/network_wrangler/models/roadway/converters.py index f64de666..b6827cd1 100644 --- a/network_wrangler/models/roadway/converters.py +++ b/network_wrangler/models/roadway/converters.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Optional +import operator from pandas import DataFrame, Series @@ -14,13 +14,13 @@ def translate_links_df_v0_to_v1( - links_df: DataFrame, complex_properties: Optional[list[str]] = None + links_df: DataFrame, complex_properties: list[str] | None = None ) -> DataFrame: """Translates a links dataframe from v0 to v1 format. Args: links_df (DataFrame): _description_ - complex_properties (Optional[list[str]], optional): List of complex properties to + complex_properties (list[str] | None, optional): List of complex properties to convert from v0 to v1 link data model. Defaults to None. If None, will detect complex properties. 
""" @@ -39,13 +39,13 @@ def translate_links_df_v0_to_v1( def translate_links_df_v1_to_v0( - links_df: DataFrame, complex_properties: Optional[list[str]] = None + links_df: DataFrame, complex_properties: list[str] | None = None ) -> DataFrame: """Translates a links dataframe from v1 to v0 format. Args: links_df (DataFrame): _description_ - complex_properties (Optional[list[str]], optional): List of complex properties to + complex_properties (list[str] | None, optional): List of complex properties to convert from v0 to v1 link data model. Defaults to None. If None, will detect complex properties. """ @@ -85,7 +85,7 @@ def _v0_to_v1_scoped_link_property_list(v0_item_list: list[dict]) -> list[dict]: Returns: list[dict]: in v1 format """ - import pprint # noqa: PLC0415 + import pprint v1_item_list = [] @@ -133,9 +133,11 @@ def _translate_scoped_link_property_v0_to_v1(links_df: DataFrame, prop: str) -> ) links_df.loc[complex_idx, f"sc_{prop}"] = links_df.loc[complex_idx, prop].apply( - lambda x: _v0_to_v1_scoped_link_property_list(x) + _v0_to_v1_scoped_link_property_list + ) + links_df.loc[complex_idx, prop] = links_df.loc[complex_idx, prop].apply( + operator.itemgetter("default") ) - links_df.loc[complex_idx, prop] = links_df.loc[complex_idx, prop].apply(lambda x: x["default"]) return links_df diff --git a/network_wrangler/models/roadway/tables.py b/network_wrangler/models/roadway/tables.py index 9945c9c0..700a6282 100644 --- a/network_wrangler/models/roadway/tables.py +++ b/network_wrangler/models/roadway/tables.py @@ -14,9 +14,8 @@ from __future__ import annotations import datetime as dt -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar -import numpy as np import pandas as pd import pandera as pa from pandas import Int64Dtype as Int64 @@ -24,8 +23,7 @@ from pandera.typing import Series from pandera.typing.geopandas import GeoSeries -from ...logger import WranglerLogger -from .._base.db import TableForeignKeys, TablePrimaryKeys +from 
.._base.db import TablePrimaryKeys from .._base.tables import validate_pyd from .types import ScopedLinkValueList @@ -70,38 +68,39 @@ class RoadLinksTable(DataFrameModel): - -1 indicates that there is a parallel managed lane derived from this link (model network). shape_id (str): Identifier referencing the primary key of the shapes table. Default is None. lanes (int): Default number of lanes on the link. Default is 1. - sc_lanes (Optional[list[dict]]: List of scoped link values for the number of lanes. Default is None. + sc_lanes (list[dict] | None): List of scoped link values for the number of lanes. Default is None. Example: `[{'timespan':['12:00':'15:00'], 'value': 3},{'timespan':['15:00':'19:00'], 'value': 2}]`. price (float): Default price to use the link. Default is 0. - sc_price (Optional[list[dict]]): List of scoped link values for the price. Default is None. + sc_price (list[dict] | None): List of scoped link values for the price. Default is None. Example: `[{'timespan':['15:00':'19:00'],'category': 'sov', 'value': 2.5}]`. - ref (Optional[str]): Reference numbers for link referring to a route or exit number per the + ref (str | None): Reference numbers for link referring to a route or exit number per the [OSM definition](https://wiki.openstreetmap.org/wiki/Key:ref). Default is None. - access (Optional[Any]): User-defined method to note access restrictions for the link. Default is None. - ML_projects (Optional[str]): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. + access (Any | None): User-defined method to note access restrictions for the link. Default is None. + ML_projects (str | None): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. Comma-separated list of project names applied to the managed lane. Default is "". - ML_lanes (Optional[int]): Default number of lanes on the managed lane. Default is None. 
- ML_price (Optional[float]): Default price to use the managed lane. Default is 0. - ML_access (Optional[Any]): User-defined method to note access restrictions for the managed lane. Default is None. - ML_access_point (Optional[bool]): If the link is an access point for the managed lane. Default is False. - ML_egress_point (Optional[bool]): If the link is an egress point for the managed lane. Default is False. - sc_ML_lanes (Optional[list[dict]]): List of scoped link values for the number of lanes on the managed lane. + ML_lanes (int | None): Default number of lanes on the managed lane. Default is None. + ML_price (float | None): Default price to use the managed lane. Default is 0. + ML_access (Any | None): User-defined method to note access restrictions for the managed lane. Default is None. + ML_access_point (bool | None): If the link is an access point for the managed lane. Default is False. + ML_egress_point (bool | None): If the link is an egress point for the managed lane. Default is False. + sc_ML_lanes (list[dict] | None): List of scoped link values for the number of lanes on the managed lane. Default is None. - sc_ML_price (Optional[list[dict]]): List of scoped link values for the price of the managed lane. Default is None. - sc_ML_access (Optional[list[dict]]): List of scoped link values for the access restrictions of the managed lane. + sc_ML_price (list[dict] | None): List of scoped link values for the price of the managed lane. Default is None. + sc_ML_access (list[dict] | None): List of scoped link values for the access restrictions of the managed lane. Default is None. - ML_geometry (Optional[GeoSeries]): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. + ML_geometry (GeoSeries | None): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. Simple A-->B geometry of the managed lane. Default is None. 
- ML_shape_id (Optional[str]): Identifier referencing the primary key of the shapes table for the managed lane. + ML_shape_id (str | None): Identifier referencing the primary key of the shapes table for the managed lane. Default is None. - osm_link_id (Optional[str]): Identifier referencing the OSM link ID. Default is "". - GP_A (Optional[int]): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. + osm_link_id (str | None): Identifier referencing the OSM link ID. Default is "". + GP_A (int | None): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. Identifier referencing the primary key of the associated general purpose link start node for a managed lane link in a model network. Default is None. - GP_B (Optional[int]): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. + GP_B (int | None): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. Identifier referencing the primary key of the associated general purpose link end node for a managed lane link in a model network. Default is None. !!! 
tip "User Defined Properties" @@ -227,7 +226,7 @@ class RoadLinksTable(DataFrameModel): """ model_link_id: Series[int] = Field(coerce=True, unique=True) - model_link_id_idx: Optional[Series[int]] = Field(coerce=True, unique=True) + model_link_id_idx: Series[int] | None = Field(coerce=True, unique=True) A: Series[int] = Field(nullable=False, coerce=True) B: Series[int] = Field(nullable=False, coerce=True) geometry: GeoSeries = Field(nullable=False) @@ -248,54 +247,54 @@ class RoadLinksTable(DataFrameModel): price: Series[float] = Field(coerce=True, nullable=False, default=0) # Optional Fields - ref: Optional[Series[str]] = Field(coerce=True, nullable=True, default=None) - access: Optional[Series[Any]] = Field(coerce=True, nullable=True, default=None) + ref: Series[str] | None = Field(coerce=True, nullable=True, default=None) + access: Series[Any] | None = Field(coerce=True, nullable=True, default=None) - sc_lanes: Optional[Series[object]] = Field(coerce=True, nullable=True, default=None) - sc_price: Optional[Series[object]] = Field(coerce=True, nullable=True, default=None) + sc_lanes: Series[object] | None = Field(coerce=True, nullable=True, default=None) + sc_price: Series[object] | None = Field(coerce=True, nullable=True, default=None) ML_projects: Series[str] = Field(coerce=True, default="") - ML_lanes: Optional[Series[Int64]] = Field(coerce=True, nullable=True, default=None) - ML_price: Optional[Series[float]] = Field(coerce=True, nullable=True, default=0) - ML_access: Optional[Series[Any]] = Field(coerce=True, nullable=True, default=True) - ML_access_point: Optional[Series[bool]] = Field( + ML_lanes: Series[Int64] | None = Field(coerce=True, nullable=True, default=None) + ML_price: Series[float] | None = Field(coerce=True, nullable=True, default=0) + ML_access: Series[Any] | None = Field(coerce=True, nullable=True, default=True) + ML_access_point: Series[bool] | None = Field( coerce=True, default=False, ) - ML_egress_point: Optional[Series[bool]] = Field( + 
ML_egress_point: Series[bool] | None = Field( coerce=True, default=False, ) - sc_ML_lanes: Optional[Series[object]] = Field( + sc_ML_lanes: Series[object] | None = Field( coerce=True, nullable=True, default=None, ) - sc_ML_price: Optional[Series[object]] = Field( + sc_ML_price: Series[object] | None = Field( coerce=True, nullable=True, default=None, ) - sc_ML_access: Optional[Series[object]] = Field( + sc_ML_access: Series[object] | None = Field( coerce=True, nullable=True, default=None, ) - ML_geometry: Optional[GeoSeries] = Field(nullable=True, coerce=True, default=None) - ML_shape_id: Optional[Series[str]] = Field(nullable=True, coerce=True, default=None) + ML_geometry: GeoSeries | None = Field(nullable=True, coerce=True, default=None) + ML_shape_id: Series[str] | None = Field(nullable=True, coerce=True, default=None) - truck_access: Optional[Series[bool]] = Field(coerce=True, nullable=True, default=True) + truck_access: Series[bool] | None = Field(coerce=True, nullable=True, default=True) osm_link_id: Series[str] = Field(coerce=True, nullable=True, default="") # todo this should be List[dict] but ranch output something else so had to have it be Any. - locationReferences: Optional[Series[Any]] = Field( + locationReferences: Series[Any] | None = Field( coerce=True, nullable=True, default="", ) - GP_A: Optional[Series[Int64]] = Field(coerce=True, nullable=True, default=None) - GP_B: Optional[Series[Int64]] = Field(coerce=True, nullable=True, default=None) + GP_A: Series[Int64] | None = Field(coerce=True, nullable=True, default=None) + GP_B: Series[Int64] | None = Field(coerce=True, nullable=True, default=None) class Config: """Config for RoadLinksTable.""" @@ -334,14 +333,14 @@ class RoadNodesTable(DataFrameModel): Attributes: model_node_id (int): Unique identifier for the node. - osm_node_id (Optional[str]): Reference to open street map node id. Used for querying. Not guaranteed to be unique. + osm_node_id (str | None): Reference to open street map node id. 
Used for querying. Not guaranteed to be unique. X (float): Longitude of the node in WGS84. Must be in the range of -180 to 180. Y (float): Latitude of the node in WGS84. Must be in the range of -90 to 90. geometry (GeoSeries): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. """ model_node_id: Series[int] = Field(coerce=True, unique=True, nullable=False) - model_node_idx: Optional[Series[int]] = Field(coerce=True, unique=True, nullable=False) + model_node_idx: Series[int] | None = Field(coerce=True, unique=True, nullable=False) X: Series[float] = Field(coerce=True, nullable=False) Y: Series[float] = Field(coerce=True, nullable=False) geometry: GeoSeries @@ -353,8 +352,8 @@ class RoadNodesTable(DataFrameModel): default="", ) projects: Series[str] = Field(coerce=True, default="") - inboundReferenceIds: Optional[Series[list[str]]] = Field(coerce=True, nullable=True) - outboundReferenceIds: Optional[Series[list[str]]] = Field(coerce=True, nullable=True) + inboundReferenceIds: Series[list[str]] | None = Field(coerce=True, nullable=True) + outboundReferenceIds: Series[list[str]] | None = Field(coerce=True, nullable=True) class Config: """Config for RoadNodesTable.""" @@ -384,15 +383,15 @@ class RoadShapesTable(DataFrameModel): shape_id (str): Unique identifier for the shape. geometry (GeoSeries): **Warning**: this attribute is controlled by wrangler and should not be explicitly user-edited. Geometry of the shape. - ref_shape_id (Optional[str]): Reference to another `shape_id` that it may + ref_shape_id (str | None): Reference to another `shape_id` that it may have been created from. Default is None. 
""" shape_id: Series[str] = Field(unique=True) - shape_id_idx: Optional[Series[int]] = Field(unique=True) + shape_id_idx: Series[int] | None = Field(unique=True) geometry: GeoSeries = Field() - ref_shape_id: Optional[Series] = Field(nullable=True) + ref_shape_id: Series | None = Field(nullable=True) class Config: """Config for RoadShapesTable.""" diff --git a/network_wrangler/models/roadway/types.py b/network_wrangler/models/roadway/types.py index 6ae6dde4..c5ffe28a 100644 --- a/network_wrangler/models/roadway/types.py +++ b/network_wrangler/models/roadway/types.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import datetime -from typing import ClassVar, Optional, Union +from typing import ClassVar from pydantic import ( BaseModel, @@ -51,9 +51,9 @@ class ScopedLinkValueItem(RecordModel): require_any_of: ClassVar[AnyOf] = [["category", "timespan"]] model_config = ConfigDict(extra="forbid") - category: Optional[Union[str, int]] = Field(default=DEFAULT_CATEGORY) - timespan: Optional[list[TimeString]] = Field(default=DEFAULT_TIMESPAN) - value: Union[int, float, str] + category: str | int | None = Field(default=DEFAULT_CATEGORY) + timespan: list[TimeString] | None = Field(default=DEFAULT_TIMESPAN) + value: int | float | str @property def timespan_dt(self) -> list[list[datetime]]: diff --git a/network_wrangler/roadway/clip.py b/network_wrangler/roadway/clip.py index 775a4201..6a054ff0 100644 --- a/network_wrangler/roadway/clip.py +++ b/network_wrangler/roadway/clip.py @@ -24,7 +24,7 @@ import copy from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import geopandas as gpd @@ -40,8 +40,8 @@ def clip_roadway_to_dfs( network: RoadwayNetwork, boundary_gdf: gpd.GeoDataFrame = None, - boundary_geocode: Optional[Union[str, dict]] = None, - boundary_file: Optional[Union[str, Path]] = None, + boundary_geocode: str | dict | None = None, + boundary_file: str | Path | None = None, ) -> tuple: """Clips a 
RoadwayNetwork object to a boundary and returns the resulting GeoDataFrames. @@ -97,8 +97,8 @@ def clip_roadway_to_dfs( def clip_roadway( network: RoadwayNetwork, boundary_gdf: gpd.GeoDataFrame = None, - boundary_geocode: Optional[Union[str, dict]] = None, - boundary_file: Optional[Union[str, Path]] = None, + boundary_geocode: str | dict | None = None, + boundary_file: str | Path | None = None, ) -> RoadwayNetwork: """Clip a RoadwayNetwork object to a boundary. @@ -124,7 +124,7 @@ def clip_roadway( boundary_geocode=boundary_geocode, boundary_file=boundary_file, ) - from .network import RoadwayNetwork # noqa: PLC0415 + from .network import RoadwayNetwork trimmed_net = RoadwayNetwork( links_df=trimmed_links_df, diff --git a/network_wrangler/roadway/graph.py b/network_wrangler/roadway/graph.py index 18607b88..2749b784 100644 --- a/network_wrangler/roadway/graph.py +++ b/network_wrangler/roadway/graph.py @@ -3,7 +3,7 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import networkx as nx import osmnx as ox @@ -166,7 +166,7 @@ def links_nodes_to_ox_graph( return G -def net_to_graph(net: RoadwayNetwork, mode: Optional[str] = None) -> nx.MultiDiGraph: +def net_to_graph(net: RoadwayNetwork, mode: str | None = None) -> nx.MultiDiGraph: """Converts a network to a MultiDiGraph. Args: @@ -185,9 +185,7 @@ def net_to_graph(net: RoadwayNetwork, mode: Optional[str] = None) -> nx.MultiDiG return G -def shortest_path( - G: nx.MultiDiGraph, O_id, D_id, sp_weight_property="weight" -) -> Union[list, None]: +def shortest_path(G: nx.MultiDiGraph, O_id, D_id, sp_weight_property="weight") -> list | None: """Calculates the shortest path between two nodes in a network. 
Args: diff --git a/network_wrangler/roadway/io.py b/network_wrangler/roadway/io.py index 3ef317fc..b13adaca 100644 --- a/network_wrangler/roadway/io.py +++ b/network_wrangler/roadway/io.py @@ -3,10 +3,8 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING -import pyarrow as pa -import pyarrow.parquet as pq from geopandas import GeoDataFrame from ..configs import ConfigInputTypes, DefaultConfig, WranglerConfig, load_wrangler_config @@ -26,13 +24,13 @@ def load_roadway( links_file: Path, nodes_file: Path, - shapes_file: Optional[Path] = None, + shapes_file: Path | None = None, in_crs: int = LAT_LON_CRS, read_in_shapes: bool = False, - boundary_gdf: Optional[GeoDataFrame] = None, - boundary_geocode: Optional[str] = None, - boundary_file: Optional[Path] = None, - filter_links_to_nodes: Optional[bool] = None, + boundary_gdf: GeoDataFrame | None = None, + boundary_geocode: str | None = None, + boundary_file: Path | None = None, + filter_links_to_nodes: bool | None = None, config: ConfigInputTypes = DefaultConfig, ) -> RoadwayNetwork: """Reads a network from the roadway network standard. 
@@ -63,7 +61,7 @@ def load_roadway( Returns: a RoadwayNetwork instance """ - from .network import RoadwayNetwork # noqa: PLC0415 + from .network import RoadwayNetwork if not isinstance(config, WranglerConfig): config = load_wrangler_config(config) @@ -121,8 +119,8 @@ def load_roadway( def id_roadway_file_paths_in_dir( - dir: Union[Path, str], file_format: RoadwayFileTypes = "geojson" -) -> tuple[Path, Path, Union[None, Path]]: + dir: Path | str, file_format: RoadwayFileTypes = "geojson" +) -> tuple[Path, Path, None | Path]: """Identifies the paths to the links, nodes, and shapes files in a directory.""" network_path = Path(dir) if not network_path.is_dir(): @@ -155,13 +153,13 @@ def id_roadway_file_paths_in_dir( def load_roadway_from_dir( - dir: Union[Path, str], + dir: Path | str, file_format: RoadwayFileTypes = "geojson", read_in_shapes: bool = False, - boundary_gdf: Optional[GeoDataFrame] = None, - boundary_geocode: Optional[str] = None, - boundary_file: Optional[Path] = None, - filter_links_to_nodes: Optional[bool] = None, + boundary_gdf: GeoDataFrame | None = None, + boundary_geocode: str | None = None, + boundary_file: Path | None = None, + filter_links_to_nodes: bool | None = None, config: ConfigInputTypes = DefaultConfig, ) -> RoadwayNetwork: """Reads a network from the roadway network standard. 
@@ -204,8 +202,8 @@ def load_roadway_from_dir( def write_roadway( - net: Union[RoadwayNetwork, ModelRoadwayNetwork], - out_dir: Union[Path, str] = ".", + net: RoadwayNetwork | ModelRoadwayNetwork, + out_dir: Path | str = ".", convert_complex_link_properties_to_single_field: bool = False, prefix: str = "", file_format: RoadwayFileTypes = "geojson", @@ -257,10 +255,10 @@ def convert_roadway_file_serialization( out_format: RoadwayFileTypes = "parquet", out_prefix: str = "", overwrite: bool = True, - boundary_gdf: Optional[GeoDataFrame] = None, - boundary_geocode: Optional[str] = None, - boundary_file: Optional[Path] = None, - chunk_size: Optional[int] = None, + boundary_gdf: GeoDataFrame | None = None, + boundary_geocode: str | None = None, + boundary_file: Path | None = None, + chunk_size: int | None = None, ): """Converts a files in a roadway from one serialization format to another without parsing. @@ -283,7 +281,7 @@ def convert_roadway_file_serialization( Chunking will only apply to converting from json to parquet files. 
""" links_in_file, nodes_in_file, shapes_in_file = id_roadway_file_paths_in_dir(in_path, in_format) - from ..utils.io_table import convert_file_serialization # noqa: PLC0415 + from ..utils.io_table import convert_file_serialization nodes_out_file = Path(out_dir / f"{out_prefix}_nodes.{out_format}") convert_file_serialization( @@ -327,15 +325,15 @@ def convert_roadway_file_serialization( def convert_roadway_network_serialization( - input_path: Union[str, Path], + input_path: str | Path, output_format: RoadwayFileTypes = "geojson", - out_dir: Union[str, Path] = ".", + out_dir: str | Path = ".", input_file_format: RoadwayFileTypes = "geojson", out_prefix: str = "", overwrite: bool = True, - boundary_gdf: Optional[GeoDataFrame] = None, - boundary_geocode: Optional[str] = None, - boundary_file: Optional[Path] = None, + boundary_gdf: GeoDataFrame | None = None, + boundary_geocode: str | None = None, + boundary_file: Path | None = None, filter_links_to_nodes: bool = False, ): """Converts a roadway network from one serialization format to another with parsing. diff --git a/network_wrangler/roadway/links/create.py b/network_wrangler/roadway/links/create.py index b6d14b34..ad3daed1 100644 --- a/network_wrangler/roadway/links/create.py +++ b/network_wrangler/roadway/links/create.py @@ -2,7 +2,6 @@ import copy import time -from typing import Optional, Union import geopandas as gpd import pandas as pd @@ -36,7 +35,7 @@ def shape_id_from_link_geometry( def _fill_missing_link_geometries_from_nodes( - links_df: pd.DataFrame, nodes_df: Optional[DataFrame[RoadNodesTable]] = None + links_df: pd.DataFrame, nodes_df: DataFrame[RoadNodesTable] | None = None ) -> gpd.GeoDataFrame: """Create location references and link geometries from nodes. 
@@ -70,9 +69,9 @@ def _fill_missing_distance_from_geometry(df: gpd.GeoDataFrame) -> gpd.GeoDataFra @validate_call_pyd def data_to_links_df( - links_df: Union[pd.DataFrame, list[dict]], + links_df: pd.DataFrame | list[dict], in_crs: int = LAT_LON_CRS, - nodes_df: Union[None, DataFrame[RoadNodesTable]] = None, + nodes_df: None | DataFrame[RoadNodesTable] = None, ) -> DataFrame[RoadLinksTable]: """Create a links dataframe from list of link properties + link geometries or associated nodes. @@ -126,11 +125,11 @@ def copy_links( links_df: DataFrame[RoadLinksTable], link_id_lookup: dict[int, int], node_id_lookup: dict[int, int], - updated_geometry_col: Optional[str] = None, - nodes_df: Optional[DataFrame[RoadNodesTable]] = None, + updated_geometry_col: str | None = None, + nodes_df: DataFrame[RoadNodesTable] | None = None, offset_meters: float = -5, - copy_properties: Optional[list[str]] = None, - rename_properties: Optional[dict[str, str]] = None, + copy_properties: list[str] | None = None, + rename_properties: dict[str, str] | None = None, name_prefix: str = "copy of", validate: bool = True, ) -> DataFrame[RoadLinksTable]: diff --git a/network_wrangler/roadway/links/delete.py b/network_wrangler/roadway/links/delete.py index 49052fbf..54e73dbd 100644 --- a/network_wrangler/roadway/links/delete.py +++ b/network_wrangler/roadway/links/delete.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pandera.typing import DataFrame @@ -10,7 +10,6 @@ from ...logger import WranglerLogger from ...models.roadway.tables import RoadLinksTable from ...transit.validate import shape_links_without_road_links -from ...utils.models import validate_call_pyd if TYPE_CHECKING: from ...transit.network import TransitNetwork @@ -20,7 +19,7 @@ def delete_links_by_ids( links_df: DataFrame[RoadLinksTable], del_link_ids: list[int], ignore_missing: bool = False, - transit_net: Optional[TransitNetwork] = None, + 
transit_net: TransitNetwork | None = None, ) -> DataFrame[RoadLinksTable]: """Delete links from a links table. diff --git a/network_wrangler/roadway/links/df_accessors.py b/network_wrangler/roadway/links/df_accessors.py index 47ba08fd..044834d9 100644 --- a/network_wrangler/roadway/links/df_accessors.py +++ b/network_wrangler/roadway/links/df_accessors.py @@ -3,6 +3,7 @@ import pandas as pd from pandera.typing import DataFrame +from ...errors import NotLinksError from ...logger import WranglerLogger from ...models.roadway.tables import RoadLinksTable, RoadShapesTable from .filters import ( @@ -21,7 +22,6 @@ filter_links_transit_only, ) from .geo import true_shape -from .links import NotLinksError from .summary import link_summary diff --git a/network_wrangler/roadway/links/edit.py b/network_wrangler/roadway/links/edit.py index d5b06e80..4a68c6e8 100644 --- a/network_wrangler/roadway/links/edit.py +++ b/network_wrangler/roadway/links/edit.py @@ -21,7 +21,7 @@ from __future__ import annotations import copy -from typing import Any, Literal, Optional, Union +from typing import Any, Literal import geopandas as gpd import numpy as np @@ -94,7 +94,7 @@ def _resolve_conflicting_scopes( def _valid_default_value_for_change(value: Any) -> bool: - if isinstance(value, (int, np.integer)): + if isinstance(value, int | np.integer): return True return bool(isinstance(value, float)) @@ -129,7 +129,7 @@ def _update_property_for_scope( @validate_call(config={"arbitrary_types_allowed": True}, validate_return=True) def _edit_scoped_link_property( - scoped_prop_value_list: Union[None, list[ScopedLinkValueItem]], + scoped_prop_value_list: None | list[ScopedLinkValueItem], scoped_prop_set: ScopedPropertySetList, default_value: Any = None, overwrite_scoped: Literal["conflicting", "all", "error"] = "error", @@ -231,7 +231,7 @@ def _edit_link_property( link_idx: list[int], prop_name: str, prop_change: dict, - project_name: Optional[str] = None, + project_name: str | None = None, config: 
WranglerConfig = DefaultConfig, ) -> DataFrame[RoadLinksTable]: """Return edited (in place) RoadLinksTable with property changes for a list of links. @@ -351,7 +351,7 @@ def edit_link_properties( links_df: DataFrame[RoadLinksTable], link_idx: list, property_changes: dict[str, dict], - project_name: Optional[str] = None, + project_name: str | None = None, config: WranglerConfig = DefaultConfig, ) -> DataFrame[RoadLinksTable]: """Return copy of RoadLinksTable with edited link properties for a list of links. diff --git a/network_wrangler/roadway/links/filters.py b/network_wrangler/roadway/links/filters.py index 0c29f806..186418a0 100644 --- a/network_wrangler/roadway/links/filters.py +++ b/network_wrangler/roadway/links/filters.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import pandas as pd @@ -114,7 +114,7 @@ def filter_links_transit_only( def filter_links_to_modes( - links_df: DataFrame[RoadLinksTable], modes: Union[str, list[str]] + links_df: DataFrame[RoadLinksTable], modes: str | list[str] ) -> DataFrame[RoadLinksTable]: """Filters links dataframe to only include links that are accessible by the modes in the list. 
@@ -162,7 +162,7 @@ def filter_links_to_ids( def filter_links_not_in_ids( - links_df: DataFrame[RoadLinksTable], link_ids: Union[list[int], pd.Series] + links_df: DataFrame[RoadLinksTable], link_ids: list[int] | pd.Series ) -> DataFrame[RoadLinksTable]: """Filters links dataframe to NOT have link_ids.""" return links_df.loc[~links_df["model_link_id"].isin(link_ids)] diff --git a/network_wrangler/roadway/links/io.py b/network_wrangler/roadway/links/io.py index b6393882..4655e806 100644 --- a/network_wrangler/roadway/links/io.py +++ b/network_wrangler/roadway/links/io.py @@ -4,7 +4,6 @@ import time from pathlib import Path -from typing import Union import pandas as pd from pandera.typing import DataFrame @@ -13,7 +12,7 @@ from ...logger import WranglerLogger from ...models._base.types import GeoFileTypes from ...models.roadway.converters import translate_links_df_v1_to_v0 -from ...models.roadway.tables import RoadLinksAttrs, RoadLinksTable, RoadNodesAttrs, RoadNodesTable +from ...models.roadway.tables import RoadLinksTable, RoadNodesTable from ...params import LAT_LON_CRS from ...utils.io_table import read_table, write_table from ...utils.models import order_fields_from_data_model, validate_call_pyd @@ -70,7 +69,7 @@ def read_links( @validate_call_pyd def write_links( links_df: DataFrame[RoadLinksTable], - out_dir: Union[str, Path] = ".", + out_dir: str | Path = ".", convert_complex_properties_to_single_field: bool = False, prefix: str = "", file_format: GeoFileTypes = "json", diff --git a/network_wrangler/roadway/links/links.py b/network_wrangler/roadway/links/links.py index 2d6d6771..9fab57ee 100644 --- a/network_wrangler/roadway/links/links.py +++ b/network_wrangler/roadway/links/links.py @@ -1,11 +1,8 @@ """Functions for querying RoadLinksTable.""" -from typing import Optional - import pandas as pd from pandera.typing import DataFrame -from ...errors import LinkNotFoundError, MissingNodesError, NotLinksError from ...logger import WranglerLogger from 
...models.roadway.tables import RoadLinksTable, RoadNodesTable, RoadShapesTable from ...utils.data import fk_in_pk @@ -14,7 +11,7 @@ def node_ids_in_links( - links_df: DataFrame[RoadLinksTable], nodes_df: Optional[DataFrame[RoadNodesTable]] = None + links_df: DataFrame[RoadLinksTable], nodes_df: DataFrame[RoadNodesTable] | None = None ) -> pd.Series: """Returns the unique node_ids in a links dataframe. @@ -36,7 +33,7 @@ def node_ids_in_links( def node_ids_in_link_ids( link_ids: list[int], links_df: DataFrame[RoadLinksTable], - nodes_df: Optional[DataFrame[RoadNodesTable]] = None, + nodes_df: DataFrame[RoadNodesTable] | None = None, ) -> pd.Series: """Returns the unique node_ids in a list of link_ids. @@ -54,7 +51,7 @@ def node_ids_in_link_ids( def node_ids_unique_to_link_ids( link_ids: list[int], links_df: DataFrame[RoadLinksTable], - nodes_df: Optional[DataFrame[RoadNodesTable]] = None, + nodes_df: DataFrame[RoadNodesTable] | None = None, ) -> list[int]: """Returns the unique node_ids in a list of link_ids that are not in other links.""" selected_link_node_ids = node_ids_in_link_ids(link_ids, links_df, nodes_df=nodes_df) @@ -65,7 +62,7 @@ def node_ids_unique_to_link_ids( def shape_ids_in_links( - links_df: DataFrame[RoadLinksTable], shapes_df: Optional[DataFrame[RoadShapesTable]] = None + links_df: DataFrame[RoadLinksTable], shapes_df: DataFrame[RoadShapesTable] | None = None ) -> list[int]: """Returns the unique shape_ids in a links dataframe. 
@@ -86,7 +83,7 @@ def shape_ids_in_links( def shape_ids_in_link_ids( link_ids: list[int], links_df: DataFrame[RoadLinksTable], - shapes_df: Optional[DataFrame[RoadShapesTable]] = None, + shapes_df: DataFrame[RoadShapesTable] | None = None, ) -> list[int]: """Returns the unique shape_ids in a list of link_ids.""" _links_df = filter_links_to_ids(links_df, link_ids) @@ -96,7 +93,7 @@ def shape_ids_in_link_ids( def shape_ids_unique_to_link_ids( link_ids: list[int], links_df: DataFrame[RoadLinksTable], - shapes_df: Optional[DataFrame[RoadShapesTable]] = None, + shapes_df: DataFrame[RoadShapesTable] | None = None, ) -> list[int]: """Returns the unique shape_ids in a list of link_ids.""" selected_link_shape_ids = shape_ids_in_link_ids(link_ids, links_df, shapes_df=shapes_df) diff --git a/network_wrangler/roadway/links/scopes.py b/network_wrangler/roadway/links/scopes.py index 64975ed3..d32003ec 100644 --- a/network_wrangler/roadway/links/scopes.py +++ b/network_wrangler/roadway/links/scopes.py @@ -32,21 +32,20 @@ """ +from __future__ import annotations + import copy -from typing import Any, TypeGuard, Union +from typing import Any, TypeGuard import pandas as pd from pandera.typing import DataFrame from pydantic import validate_call -from typing_extensions import TypeGuard -from ...errors import InvalidScopedLinkValue from ...logger import WranglerLogger from ...models._base.types import TimeString from ...models.projects.roadway_changes import IndivScopedPropertySetItem from ...models.roadway.tables import ( ExplodedScopedLinkPropertyTable, - RoadLinksAttrs, RoadLinksTable, ) from ...models.roadway.types import ScopedLinkValueItem @@ -62,7 +61,7 @@ def _convert_to_scoped_items( - scoped_values: list[Union[ScopedLinkValueItem, dict]], + scoped_values: list[ScopedLinkValueItem | dict], ) -> list[ScopedLinkValueItem]: """Convert dictionaries to ScopedLinkValueItem objects if needed.""" converted = [] @@ -75,7 +74,7 @@ def _convert_to_scoped_items( def 
_filter_to_matching_timespan_scopes( - scoped_values: list[Union[ScopedLinkValueItem, dict]], timespan: list[TimeString] + scoped_values: list[ScopedLinkValueItem | dict], timespan: list[TimeString] ) -> list[ScopedLinkValueItem]: """Filters list of ScopedLinkValueItems to list of those that match. @@ -88,20 +87,16 @@ def _filter_to_matching_timespan_scopes( if timespan == DEFAULT_TIMESPAN: return scoped_values times_dt = list(map(str_to_time, timespan)) - # typeguard - mypy suggested this b/c we cannot guarantee we got rid of all the Nones return [ s for s in scoped_values - if ( - _islist(s.timespan) - and dt_contains([str_to_time(i) for i in _islist(s.timespan)], times_dt) - ) + if (_islist(s.timespan) and dt_contains([str_to_time(i) for i in s.timespan], times_dt)) or s.timespan == DEFAULT_TIMESPAN ] def _filter_to_matching_category_scopes( - scoped_values: list[Union[ScopedLinkValueItem, dict]], category: Union[str, list] + scoped_values: list[ScopedLinkValueItem | dict], category: str | list ) -> list[ScopedLinkValueItem]: """Filters list of ScopedLinkValueItems to list of those that match. @@ -115,8 +110,8 @@ def _filter_to_matching_category_scopes( def _filter_to_matching_scope( - scoped_values: list[Union[ScopedLinkValueItem, dict]], - category: Union[str, list] = DEFAULT_CATEGORY, + scoped_values: list[ScopedLinkValueItem | dict], + category: str | list = DEFAULT_CATEGORY, timespan: list[TimeString] = DEFAULT_TIMESPAN, ) -> tuple[list[ScopedLinkValueItem], list[ScopedLinkValueItem]]: """Filters list of ScopedLinkValueItems to list of those that match. @@ -133,7 +128,7 @@ def _filter_to_matching_scope( def _filter_to_overlapping_timespan_scopes( - scoped_values: list[Union[ScopedLinkValueItem, dict]], timespan: list[TimeString] + scoped_values: list[ScopedLinkValueItem | dict], timespan: list[TimeString] ) -> list[ScopedLinkValueItem]: """Filters list of ScopedLinkValueItems to list of those that overlap. 
@@ -146,36 +141,30 @@ def _filter_to_overlapping_timespan_scopes( if timespan == DEFAULT_TIMESPAN: return scoped_values q_timespan_dt = list(map(str_to_time, timespan)) - # typeguard - mypy suggested this b/c we cannot guarantee we got rid of all the Nones return [ s for s in scoped_values if ( _islist(s.timespan) - and dt_list_overlaps([q_timespan_dt, [str_to_time(i) for i in _islist(s.timespan)]]) + and dt_list_overlaps([q_timespan_dt, [str_to_time(i) for i in s.timespan]]) ) or s.timespan == DEFAULT_TIMESPAN ] -def _islist(s: Any) -> TypeGuard[list]: - """Typeguard for list to make mypy not complain.""" - if s is list: - return s - if isinstance(s, list): - return s # type: ignore # noqa: PGH003 - is_list = bool(issubclass(type(s), list)) - if is_list: - return s - msg = f"{s} is not a list but is required to be one." - raise TypeError(msg) +def _islist(s: Any) -> TypeGuard[list[str]]: + """Type guard for list to make mypy not complain. + + Returns True if s is a list, allowing mypy to narrow the type. + """ + return isinstance(s, list) def _filter_to_overlapping_scopes( - scoped_prop_list: list[Union[ScopedLinkValueItem, IndivScopedPropertySetItem, dict]], - category: Union[str, list] = DEFAULT_CATEGORY, + scoped_prop_list: list[ScopedLinkValueItem | IndivScopedPropertySetItem | dict], + category: str | list = DEFAULT_CATEGORY, timespan: list[TimeString] = DEFAULT_TIMESPAN, -) -> list[Union[ScopedLinkValueItem, IndivScopedPropertySetItem]]: +) -> list[ScopedLinkValueItem | IndivScopedPropertySetItem]: """Filter a list of IndivScopedPropertySetItem and ScopedLinkValueItems to a specific scope. Defaults are considered to overlap everything in their scope dimension. 
@@ -198,7 +187,7 @@ def _filter_to_overlapping_scopes( def _filter_to_conflicting_timespan_scopes( - scoped_values: list[Union[ScopedLinkValueItem, dict]], timespan: list[TimeString] + scoped_values: list[ScopedLinkValueItem | dict], timespan: list[TimeString] ) -> list[ScopedLinkValueItem]: """Filters scoped values to only include those that conflict with the timespan. @@ -222,9 +211,9 @@ def _filter_to_conflicting_timespan_scopes( def _filter_to_conflicting_scopes( - scoped_values: list[Union[ScopedLinkValueItem, dict]], + scoped_values: list[ScopedLinkValueItem | dict], timespan: list[TimeString], - category: Union[str, list[str]], + category: str | list[str], ) -> list[ScopedLinkValueItem]: """Filters scoped values to only include those that conflict with the timespan. @@ -299,11 +288,14 @@ def _create_exploded_df_for_scoped_prop( return exp_df -# @validate_call(config={"arbitrary_types_allowed": True}) +@validate_call( + config={"arbitrary_types_allowed": True}, + validate_return=False, +) def _filter_exploded_df_to_scope( - exp_scoped_prop_df: DataFrame[ExplodedScopedLinkPropertyTable], + exp_scoped_prop_df: pd.DataFrame, timespan: list[TimeString] = DEFAULT_TIMESPAN, - category: Union[str, int] = DEFAULT_CATEGORY, + category: str | int = DEFAULT_CATEGORY, strict_timespan_match: bool = False, min_overlap_minutes: int = 60, ) -> pd.DataFrame: @@ -343,10 +335,10 @@ def _filter_exploded_df_to_scope( @validate_call_pyd def prop_for_scope( - links_df: DataFrame[RoadLinksTable], + links_df: pd.DataFrame, prop_name: str, - timespan: Union[None, list[TimeString]] = DEFAULT_TIMESPAN, - category: Union[str, int, None] = DEFAULT_CATEGORY, + timespan: None | list[TimeString] = DEFAULT_TIMESPAN, + category: str | int | None = DEFAULT_CATEGORY, strict_timespan_match: bool = False, min_overlap_minutes: int = 60, allow_default: bool = True, diff --git a/network_wrangler/roadway/links/validate.py b/network_wrangler/roadway/links/validate.py index ff60ee7d..fcecede4 100644 --- 
a/network_wrangler/roadway/links/validate.py +++ b/network_wrangler/roadway/links/validate.py @@ -1,12 +1,8 @@ """Utilities for validating a RoadLinksTable beyond its data model.""" from pathlib import Path -from typing import Optional -import ijson import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq from ...errors import NodesInLinksMissingError from ...logger import WranglerLogger @@ -31,7 +27,7 @@ def validate_links_have_nodes(links_df: pd.DataFrame, nodes_df: pd.DataFrame) -> def validate_links_file( links_filename: Path, - nodes_df: Optional[pd.DataFrame] = None, + nodes_df: pd.DataFrame | None = None, strict: bool = False, errors_filename: Path = Path("link_errors.csv"), ) -> bool: @@ -55,7 +51,7 @@ def validate_links_file( def validate_links_df( links_df: pd.DataFrame, - nodes_df: Optional[pd.DataFrame] = None, + nodes_df: pd.DataFrame | None = None, strict: bool = False, errors_filename: Path = Path("link_errors.csv"), ) -> bool: @@ -71,13 +67,13 @@ def validate_links_df( Returns: bool: True if the links dataframe is valid. 
""" - from ...models.roadway.tables import RoadLinksTable # noqa: PLC0415 - from ...utils.models import TableValidationError, validate_df_to_model # noqa: PLC0415 + from ...models.roadway.tables import RoadLinksTable + from ...utils.models import TableValidationError, validate_df_to_model is_valid = True if not strict: - from .create import data_to_links_df # noqa: PLC0415 + from .create import data_to_links_df try: links_df = data_to_links_df(links_df) diff --git a/network_wrangler/roadway/model_roadway.py b/network_wrangler/roadway/model_roadway.py index 08b8f425..e5365733 100644 --- a/network_wrangler/roadway/model_roadway.py +++ b/network_wrangler/roadway/model_roadway.py @@ -4,7 +4,7 @@ import copy from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import geopandas as gpd import pandas as pd @@ -83,15 +83,15 @@ class ModelRoadwayNetwork: managed lane counterparts. ml_node_id_lookup: lookup from general purpose node ids to node ids of their managed lane counterparts. - _net_hash: hash of the the input links and nodes in order to detect changes. + _net_version: modification version of the input network when this was created. """ def __init__( self, net, - ml_link_id_lookup: Optional[dict[int, int]] = None, - ml_node_id_lookup: Optional[dict[int, int]] = None, + ml_link_id_lookup: dict[int, int] | None = None, + ml_node_id_lookup: dict[int, int] | None = None, ): """Constructor for ModelRoadwayNetwork. @@ -147,7 +147,7 @@ def __init__( self.links_df, self.nodes_df = model_links_nodes_from_net( self.net, self.ml_link_id_lookup, self.ml_node_id_lookup ) - self._net_hash = copy.deepcopy(net.network_hash) + self._net_version = net.modification_version @property def ml_config(self) -> dict: @@ -237,7 +237,7 @@ def _generate_ml_link_id_lookup_from_range(links_df, link_id_range: tuple[int]): available for provided range: {link_id_range}." 
raise ValueError(msg) new_link_ids = list(avail_ml_link_ids)[: len(og_ml_link_ids)] - return dict(zip(og_ml_link_ids, new_link_ids)) + return dict(zip(og_ml_link_ids, new_link_ids, strict=True)) def _generate_ml_node_id_from_range(nodes_df, links_df, node_id_range: tuple[int]): @@ -249,7 +249,7 @@ def _generate_ml_node_id_from_range(nodes_df, links_df, node_id_range: tuple[int available for provided range: {node_id_range}." raise ValueError(msg) new_ml_node_ids = list(avail_ml_node_ids)[: len(og_ml_node_ids)] - return dict(zip(og_ml_node_ids.tolist(), new_ml_node_ids)) + return dict(zip(og_ml_node_ids.tolist(), new_ml_node_ids, strict=True)) def _generate_ml_link_id_lookup_from_scalar(links_df: DataFrame[RoadLinksTable], scalar: int): @@ -259,7 +259,7 @@ def _generate_ml_link_id_lookup_from_scalar(links_df: DataFrame[RoadLinksTable], if links_df.model_link_id.isin(link_id_list).any(): msg = f"New link ids generated by scalar {scalar} already exist. Try a different scalar." raise ValueError(msg) - return dict(zip(og_ml_link_ids, link_id_list)) + return dict(zip(og_ml_link_ids, link_id_list, strict=True)) def _generate_ml_node_id_lookup_from_scalar(nodes_df, links_df, scalar: int): @@ -269,7 +269,7 @@ def _generate_ml_node_id_lookup_from_scalar(nodes_df, links_df, scalar: int): if nodes_df.model_node_id.isin(node_id_list).any(): msg = f"New node ids generated by scalar {scalar} already exist. Try a different scalar." 
raise ValueError(msg) - return dict(zip(og_ml_node_ids.tolist(), node_id_list.tolist())) + return dict(zip(og_ml_node_ids.tolist(), node_id_list.tolist(), strict=True)) def model_links_nodes_from_net( @@ -382,7 +382,7 @@ def _create_separate_managed_lane_links( } ml_props = filter_link_properties_managed_lanes(links_df) - ml_rename_props = dict(zip(ml_props, strip_ML_from_prop_list(ml_props))) + ml_rename_props = dict(zip(ml_props, strip_ML_from_prop_list(ml_props), strict=True)) ml_links_df = copy_links( links_df.of_type.managed, @@ -469,7 +469,7 @@ def _create_dummy_connector_links( # 3 - Determine property values access_egress_df["lanes"] = 1 access_egress_df = access_egress_df.rename( - columns=dict(zip(copy_cols, strip_ML_from_prop_list(copy_cols))) + columns=dict(zip(copy_cols, strip_ML_from_prop_list(copy_cols), strict=True)) ) # 5 - Add geometry diff --git a/network_wrangler/roadway/network.py b/network_wrangler/roadway/network.py index 15902a1b..56457ed0 100644 --- a/network_wrangler/roadway/network.py +++ b/network_wrangler/roadway/network.py @@ -21,16 +21,13 @@ import hashlib from collections import defaultdict from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any import geopandas as gpd -import ijson import networkx as nx import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq from projectcard import ProjectCard, SubProject -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, ConfigDict, field_validator from ..configs import DefaultConfig, WranglerConfig, load_wrangler_config from ..errors import ( @@ -48,13 +45,11 @@ from ..params import DEFAULT_CATEGORY, DEFAULT_TIMESPAN, LAT_LON_CRS from ..utils.data import concat_with_attr from ..utils.models import empty_df_from_datamodel, validate_df_to_model -from .graph import net_to_graph from .links.create import data_to_links_df from .links.delete import delete_links_by_ids from .links.edit import 
edit_link_geometry_from_nodes from .links.filters import filter_links_to_ids, filter_links_to_node_ids from .links.links import node_ids_unique_to_link_ids, shape_ids_unique_to_link_ids -from .links.scopes import prop_for_scope from .model_roadway import ModelRoadwayNetwork from .nodes.create import data_to_nodes_df from .nodes.delete import delete_nodes_by_ids @@ -76,7 +71,6 @@ from .shapes.delete import delete_shapes_by_ids from .shapes.edit import edit_shape_geometry_from_nodes from .shapes.io import read_shapes -from .shapes.shapes import shape_ids_without_links if TYPE_CHECKING: from networkx import MultiDiGraph @@ -85,7 +79,7 @@ from ..transit.network import TransitNetwork -Selections = Union[RoadwayLinkSelection, RoadwayNodeSelection] +Selections = RoadwayLinkSelection | RoadwayNodeSelection class RoadwayNetwork(BaseModel): @@ -145,41 +139,61 @@ class RoadwayNetwork(BaseModel): network_hash: dynamic property of the hashed value of links_df and nodes_df. Used for quickly identifying if a network has changed since various expensive operations have taken place (i.e. generating a ModelRoadwayNetwork or a network graph) + _modification_version (int): counter that increments each time the network is modified. + Used for efficient change detection without expensive hash computation. model_net (ModelRoadwayNetwork): referenced `ModelRoadwayNetwork` object which will be - lazily created if None or if the `network_hash` has changed. + lazily created if None or if the network has been modified. 
config (WranglerConfig): wrangler configuration object """ - model_config = {"arbitrary_types_allowed": True} + model_config = ConfigDict(arbitrary_types_allowed=True) nodes_df: pd.DataFrame links_df: pd.DataFrame - _shapes_df: Optional[pd.DataFrame] = None + _shapes_df: pd.DataFrame | None = None - _links_file: Optional[Path] = None - _nodes_file: Optional[Path] = None - _shapes_file: Optional[Path] = None + _links_file: Path | None = None + _nodes_file: Path | None = None + _shapes_file: Path | None = None config: WranglerConfig = DefaultConfig _model_net: Optional[ModelRoadwayNetwork] = None + _model_net_version: int = -1 # Version when model_net was last created _selections: dict[str, Selections] = {} _modal_graphs: dict[str, dict] = defaultdict(lambda: {"graph": None, "hash": None}) + _modification_version: int = 0 # Incremented each time network is modified @field_validator("config") + @classmethod def validate_config(cls, v): """Validate config.""" return load_wrangler_config(v) - @field_validator("nodes_df", "links_df") - def coerce_crs(cls, v): - """Coerce crs of nodes_df and links_df to LAT_LON_CRS.""" - if v.crs != LAT_LON_CRS: + @field_validator("nodes_df", mode="before") + @classmethod + def validate_nodes_df(cls, v): + """Validate nodes_df to RoadNodesTable and coerce CRS.""" + v = validate_df_to_model(v, RoadNodesTable) + if hasattr(v, "crs") and v.crs != LAT_LON_CRS: + WranglerLogger.warning( + f"CRS of nodes_df ({v.crs}) doesn't match network crs {LAT_LON_CRS}. \ + Changing to network crs." + ) + v = v.to_crs(LAT_LON_CRS) + return v + + @field_validator("links_df", mode="before") + @classmethod + def validate_links_df(cls, v): + """Validate links_df to RoadLinksTable and coerce CRS.""" + v = validate_df_to_model(v, RoadLinksTable) + if hasattr(v, "crs") and v.crs != LAT_LON_CRS: WranglerLogger.warning( f"CRS of links_df ({v.crs}) doesn't match network crs {LAT_LON_CRS}. \ Changing to network crs." 
) - v.to_crs(LAT_LON_CRS) + v = v.to_crs(LAT_LON_CRS) return v @property @@ -207,9 +221,33 @@ def shapes_df(self) -> pd.DataFrame: def shapes_df(self, value): self._shapes_df = df_to_shapes_df(value, config=self.config) + def _mark_modified(self) -> None: + """Mark the network as modified by incrementing the modification version. + + This should be called whenever the network data is modified to ensure + that dependent computations (selections, model networks, graphs) are + re-evaluated. Uses a simple version counter which is much faster than + computing hashes for change detection. + """ + self._modification_version += 1 + WranglerLogger.debug(f"Network modified. Version: {self._modification_version}") + + @property + def modification_version(self) -> int: + """Return the current modification version of the network. + + This counter increments each time the network is modified and can be used + for efficient change detection without computing expensive hashes. + """ + return self._modification_version + @property def network_hash(self) -> str: - """Hash of the links and nodes dataframes.""" + """Hash of the links and nodes dataframes. + + Note: This is an expensive operation. For change detection, prefer using + modification_version which is much faster. + """ _value = str.encode(self.links_df.df_hash() + "-" + self.nodes_df.df_hash()) _hash = hashlib.sha256(_value).hexdigest() @@ -217,9 +255,14 @@ def network_hash(self) -> str: @property def model_net(self) -> ModelRoadwayNetwork: - """Return a ModelRoadwayNetwork object for this network.""" - if self._model_net is None or self._model_net._net_hash != self.network_hash: + """Return a ModelRoadwayNetwork object for this network. + + The model network is lazily created and cached. It is invalidated when + the network's modification version changes. 
+ """ + if self._model_net is None or self._model_net_version != self._modification_version: self._model_net = ModelRoadwayNetwork(self) + self._model_net_version = self._modification_version return self._model_net @property @@ -254,8 +297,8 @@ def link_shapes_df(self) -> gpd.GeoDataFrame: def get_property_by_timespan_and_group( self, link_property: str, - category: Optional[Union[str, int]] = DEFAULT_CATEGORY, - timespan: Optional[TimespanString] = DEFAULT_TIMESPAN, + category: str | int | None = DEFAULT_CATEGORY, + timespan: TimespanString | None = DEFAULT_TIMESPAN, strict_timespan_match: bool = False, min_overlap_minutes: int = 60, ) -> Any: @@ -273,7 +316,7 @@ def get_property_by_timespan_and_group( min_overlap_minutes: If strict_timespan_match is False, will return links that overlap with the timespan by at least this many minutes. Defaults to 60. """ - from .links.scopes import prop_for_scope # noqa: PLC0415 + from .links.scopes import prop_for_scope return prop_for_scope( self.links_df, @@ -286,9 +329,9 @@ def get_property_by_timespan_and_group( def get_selection( self, - selection_dict: Union[dict, SelectFacility], + selection_dict: dict | SelectFacility, overwrite: bool = False, - ) -> Union[RoadwayNodeSelection, RoadwayLinkSelection]: + ) -> RoadwayNodeSelection | RoadwayLinkSelection: """Return selection if it already exists, otherwise performs selection. Args: @@ -325,7 +368,11 @@ def get_selection( raise SelectionError(msg) def modal_graph_hash(self, mode) -> str: - """Hash of the links in order to detect a network change from when graph created.""" + """Hash of the links in order to detect a network change from when graph created. + + Note: This is an expensive operation. For internal change detection, + get_modal_graph uses modification_version instead. 
+ """ _value = str.encode(self.links_df.df_hash() + "-" + mode) _hash = hashlib.sha256(_value).hexdigest() @@ -337,17 +384,20 @@ def get_modal_graph(self, mode) -> MultiDiGraph: Args: mode: mode of the network, one of `drive`,`transit`,`walk`, `bike` """ - from .graph import net_to_graph # noqa: PLC0415 + from .graph import net_to_graph - if self._modal_graphs[mode]["hash"] != self.modal_graph_hash(mode): + # Use modification version for efficient change detection + current_version = (self._modification_version, mode) + if self._modal_graphs[mode].get("version") != current_version: self._modal_graphs[mode]["graph"] = net_to_graph(self, mode) + self._modal_graphs[mode]["version"] = current_version return self._modal_graphs[mode]["graph"] def apply( self, - project_card: Union[ProjectCard, dict], - transit_net: Optional[TransitNetwork] = None, + project_card: ProjectCard | dict, + transit_net: TransitNetwork | None = None, **kwargs, ) -> RoadwayNetwork: """Wrapper method to apply a roadway project, returning a new RoadwayNetwork instance. @@ -359,7 +409,7 @@ def apply( skip anything related to transit network. 
**kwargs: keyword arguments to pass to project application """ - if not (isinstance(project_card, (ProjectCard, SubProject))): + if not (isinstance(project_card, ProjectCard | SubProject)): project_card = ProjectCard(project_card) # project_card.validate() @@ -377,8 +427,8 @@ def apply( def _apply_change( self, - change: Union[ProjectCard, SubProject], - transit_net: Optional[TransitNetwork] = None, + change: ProjectCard | SubProject, + transit_net: TransitNetwork | None = None, ) -> RoadwayNetwork: """Apply a single change: a single-project project or a sub-project.""" if not isinstance(change, SubProject): @@ -460,6 +510,7 @@ def add_links( self.links_df = validate_df_to_model( concat_with_attr([self.links_df, add_links_df], axis=0), RoadLinksTable ) + self._mark_modified() def add_nodes( self, @@ -488,6 +539,7 @@ def add_nodes( if self.nodes_df.attrs.get("name") != "road_nodes": msg = f"Expected nodes_df to have name 'road_nodes', got {self.nodes_df.attrs.get('name')}" raise NotNodesError(msg) + self._mark_modified() def add_shapes( self, @@ -515,13 +567,15 @@ def add_shapes( self.shapes_df = validate_df_to_model( concat_with_attr([self.shapes_df, add_shapes_df], axis=0), RoadShapesTable ) + # Note: shapes don't affect network_hash (only links and nodes), but we invalidate + # for consistency in case future changes include shapes in hash calculation def delete_links( self, - selection_dict: Union[dict, SelectLinksDict], + selection_dict: dict | SelectLinksDict, clean_nodes: bool = False, clean_shapes: bool = False, - transit_net: Optional[TransitNetwork] = None, + transit_net: TransitNetwork | None = None, ): """Deletes links based on selection dictionary and optionally associated nodes and shapes. 
@@ -575,10 +629,11 @@ def delete_links( ignore_missing=selection.ignore_missing, transit_net=transit_net, ) + self._mark_modified() def delete_nodes( self, - selection_dict: Union[dict, SelectNodesDict], + selection_dict: dict | SelectNodesDict, remove_links: bool = False, ) -> None: """Deletes nodes from roadway network. Wont delete nodes used by links in network. @@ -610,23 +665,26 @@ def delete_nodes( self.nodes_df = delete_nodes_by_ids( self.nodes_df, del_node_ids, ignore_missing=selection.ignore_missing ) + self._mark_modified() def clean_unused_shapes(self): """Removes any unused shapes from network that aren't referenced by links_df.""" - from .shapes.shapes import shape_ids_without_links # noqa: PLC0415 + from .shapes.shapes import shape_ids_without_links del_shape_ids = shape_ids_without_links(self.shapes_df, self.links_df) self.shapes_df = self.shapes_df.drop(del_shape_ids) + # Note: shapes don't affect network_hash, but invalidate for consistency def clean_unused_nodes(self): """Removes any unused nodes from network that aren't referenced by links_df. NOTE: does not check if these nodes are used by transit, so use with caution. """ - from .nodes.nodes import node_ids_without_links # noqa: PLC0415 + from .nodes.nodes import node_ids_without_links node_ids = node_ids_without_links(self.nodes_df, self.links_df) self.nodes_df = self.nodes_df.drop(node_ids) + self._mark_modified() def move_nodes( self, @@ -645,6 +703,7 @@ def move_nodes( self.shapes_df = edit_shape_geometry_from_nodes( self.shapes_df, self.links_df, self.nodes_df, node_ids ) + self._mark_modified() def has_node(self, model_node_id: int) -> bool: """Queries if network has node based on model_node_id. @@ -684,7 +743,7 @@ def is_connected(self, mode: str) -> bool: def add_incident_link_data_to_nodes( links_df: pd.DataFrame, nodes_df: pd.DataFrame, - link_variables: Optional[list] = None, + link_variables: list | None = None, ) -> pd.DataFrame: """Add data from links going to/from nodes to node. 
diff --git a/network_wrangler/roadway/nodes/create.py b/network_wrangler/roadway/nodes/create.py index 4cab8100..20b9f725 100644 --- a/network_wrangler/roadway/nodes/create.py +++ b/network_wrangler/roadway/nodes/create.py @@ -2,7 +2,6 @@ import copy import time -from typing import Union import geopandas as gpd import pandas as pd @@ -20,7 +19,7 @@ def _create_node_geometries_from_xy( - nodes_df: Union[pd.DataFrame, list[dict]], + nodes_df: pd.DataFrame | list[dict], in_crs: int = LAT_LON_CRS, net_crs: int = LAT_LON_CRS, ) -> gpd.GeoDataFrame: @@ -64,7 +63,7 @@ def _create_node_geometries_from_xy( @validate_call(config={"arbitrary_types_allowed": True}) def data_to_nodes_df( - nodes_df: Union[pd.DataFrame, gpd.GeoDataFrame, list[dict]], + nodes_df: pd.DataFrame | gpd.GeoDataFrame | list[dict], config: WranglerConfig = DefaultConfig, # noqa: ARG001 in_crs: int = LAT_LON_CRS, ) -> DataFrame[RoadNodesTable]: diff --git a/network_wrangler/roadway/nodes/edit.py b/network_wrangler/roadway/nodes/edit.py index 78d265df..d2b3738b 100644 --- a/network_wrangler/roadway/nodes/edit.py +++ b/network_wrangler/roadway/nodes/edit.py @@ -5,7 +5,6 @@ """ import copy -from typing import Optional, Union import geopandas as gpd from pandera import DataFrameModel, Field @@ -43,7 +42,7 @@ class NodeGeometryChange(RecordModel): model_config = ConfigDict(extra="ignore") X: float Y: float - in_crs: Optional[int] = LAT_LON_CRS + in_crs: int | None = LAT_LON_CRS @validate_call_pyd @@ -92,8 +91,8 @@ def edit_node_property( nodes_df: DataFrame[RoadNodesTable], node_idx: list[int], prop_name: str, - prop_change: Union[dict, RoadPropertyChange], - project_name: Optional[str] = None, + prop_change: dict | RoadPropertyChange, + project_name: str | None = None, config: WranglerConfig = DefaultConfig, _geometry_ok: bool = False, ) -> DataFrame[RoadNodesTable]: diff --git a/network_wrangler/roadway/nodes/filters.py b/network_wrangler/roadway/nodes/filters.py index 79c8320d..3d2aa4ab 100644 --- 
a/network_wrangler/roadway/nodes/filters.py +++ b/network_wrangler/roadway/nodes/filters.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from pandera.typing import DataFrame @@ -31,7 +31,7 @@ def filter_nodes_to_ids( def filter_nodes_to_link_ids( link_ids: list[int], links_df: DataFrame[RoadLinksTable], - nodes_df: Optional[DataFrame[RoadNodesTable]] = None, + nodes_df: DataFrame[RoadNodesTable] | None = None, ) -> DataFrame[RoadNodesTable]: """Filters nodes dataframe to those used by given link_ids. diff --git a/network_wrangler/roadway/nodes/io.py b/network_wrangler/roadway/nodes/io.py index 4ef7a696..b9961a20 100644 --- a/network_wrangler/roadway/nodes/io.py +++ b/network_wrangler/roadway/nodes/io.py @@ -4,7 +4,7 @@ import time from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from geopandas import GeoDataFrame from pandera.typing import DataFrame @@ -28,9 +28,9 @@ def read_nodes( filename: Path, in_crs: int = LAT_LON_CRS, - boundary_gdf: Optional[GeoDataFrame] = None, - boundary_geocode: Optional[str] = None, - boundary_file: Optional[Path] = None, + boundary_gdf: GeoDataFrame | None = None, + boundary_geocode: str | None = None, + boundary_file: Path | None = None, config: WranglerConfig = DefaultConfig, ) -> DataFrame[RoadNodesTable]: """Reads nodes and returns a geodataframe of nodes. 
@@ -100,7 +100,7 @@ def nodes_df_to_geojson(nodes_df: DataFrame[RoadNodesTable], properties: list[st @validate_call_pyd def write_nodes( nodes_df: DataFrame[RoadNodesTable], - out_dir: Union[str, Path], + out_dir: str | Path, prefix: str, file_format: GeoFileTypes = "geojson", overwrite: bool = True, @@ -121,9 +121,9 @@ def write_nodes( def get_nodes( - transit_net: Optional[TransitNetwork] = None, - roadway_net: Optional[RoadwayNetwork] = None, - roadway_path: Optional[Union[str, Path]] = None, + transit_net: TransitNetwork | None = None, + roadway_net: RoadwayNetwork | None = None, + roadway_path: str | Path | None = None, config: WranglerConfig = DefaultConfig, ) -> GeoDataFrame: """Get nodes from a transit network, roadway network, or roadway file. diff --git a/network_wrangler/roadway/nodes/validate.py b/network_wrangler/roadway/nodes/validate.py index 152725b7..d7c59012 100644 --- a/network_wrangler/roadway/nodes/validate.py +++ b/network_wrangler/roadway/nodes/validate.py @@ -5,7 +5,6 @@ from pathlib import Path import pandas as pd -import pyarrow as pa from ...logger import WranglerLogger from ...models.roadway.tables import RoadNodesTable @@ -47,7 +46,7 @@ def validate_nodes_df( is_valid = True if not strict: - from .create import data_to_nodes_df # noqa: PLC0415 + from .create import data_to_nodes_df try: nodes_df = data_to_nodes_df(nodes_df) diff --git a/network_wrangler/roadway/projects/add.py b/network_wrangler/roadway/projects/add.py index b087b7e5..c5438aec 100644 --- a/network_wrangler/roadway/projects/add.py +++ b/network_wrangler/roadway/projects/add.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pandas as pd @@ -18,7 +18,7 @@ def apply_new_roadway( roadway_net: RoadwayNetwork, roadway_addition: dict, - project_name: Optional[str] = None, + project_name: str | None = None, ) -> RoadwayNetwork: """Add the new roadway features defined in the project card. 
diff --git a/network_wrangler/roadway/projects/calculate.py b/network_wrangler/roadway/projects/calculate.py index 73a5b8a9..e518311e 100644 --- a/network_wrangler/roadway/projects/calculate.py +++ b/network_wrangler/roadway/projects/calculate.py @@ -23,5 +23,7 @@ def apply_calculated_roadway( WranglerLogger.debug("Applying calculated roadway project.") self = roadway_net exec(pycode) + # Invalidate network state since exec'd code may have modified links_df or nodes_df + roadway_net._mark_modified() return roadway_net diff --git a/network_wrangler/roadway/projects/delete.py b/network_wrangler/roadway/projects/delete.py index c9968f4f..c536a6ce 100644 --- a/network_wrangler/roadway/projects/delete.py +++ b/network_wrangler/roadway/projects/delete.py @@ -2,9 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union - -import pandas as pd +from typing import TYPE_CHECKING from ...logger import WranglerLogger from ...models.projects.roadway_changes import RoadwayDeletion @@ -16,8 +14,8 @@ def apply_roadway_deletion( roadway_net: RoadwayNetwork, - roadway_deletion: Union[dict, RoadwayDeletion], - transit_net: Optional[TransitNetwork] = None, + roadway_deletion: dict | RoadwayDeletion, + transit_net: TransitNetwork | None = None, ) -> RoadwayNetwork: """Delete the roadway links or nodes defined in the project card. 
diff --git a/network_wrangler/roadway/projects/edit_property.py b/network_wrangler/roadway/projects/edit_property.py index 6f9781b2..9acda2de 100644 --- a/network_wrangler/roadway/projects/edit_property.py +++ b/network_wrangler/roadway/projects/edit_property.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import pandas as pd @@ -21,7 +21,7 @@ def _node_geo_change_from_property_changes( property_changes: dict[str, RoadPropertyChange], node_idx: list[int], -) -> Union[None, NodeGeometryChangeTable]: +) -> None | NodeGeometryChangeTable: """Return NodeGeometryChangeTable if property_changes includes gometry change else None.""" geo_change_present = any(f in property_changes for f in ["X", "Y"]) if not geo_change_present: @@ -48,9 +48,9 @@ def _node_geo_change_from_property_changes( def apply_roadway_property_change( roadway_net: RoadwayNetwork, - selection: Union[RoadwayNodeSelection, RoadwayLinkSelection], + selection: RoadwayNodeSelection | RoadwayLinkSelection, property_changes: dict[str, RoadPropertyChange], - project_name: Optional[str] = None, + project_name: str | None = None, ) -> RoadwayNetwork: """Changes roadway properties for the selected features based on the project card. 
@@ -79,6 +79,7 @@ def apply_roadway_property_change( property_changes, project_name=project_name, ) + roadway_net._mark_modified() elif isinstance(selection, RoadwayNodeSelection): non_geo_changes = { @@ -93,11 +94,14 @@ def apply_roadway_property_change( prop_change, project_name=project_name, ) + if non_geo_changes: + roadway_net._mark_modified() geo_changes_df = _node_geo_change_from_property_changes( property_changes, selection.selected_nodes ) if geo_changes_df is not None: + # move_nodes already calls _mark_modified() roadway_net.move_nodes(geo_changes_df) else: diff --git a/network_wrangler/roadway/segment.py b/network_wrangler/roadway/segment.py index b734e51f..ad8def6a 100644 --- a/network_wrangler/roadway/segment.py +++ b/network_wrangler/roadway/segment.py @@ -23,16 +23,15 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import numpy as np import pandas as pd -import pyarrow as pa from pandera.typing import DataFrame from ..errors import SegmentFormatError, SegmentSelectionError, SubnetCreationError from ..logger import WranglerLogger -from ..models.projects.roadway_selection import SelectLinksDict, SelectNodeDict +from ..models.projects.roadway_selection import SelectNodeDict from ..params import DEFAULT_SEARCH_MODES from .graph import shortest_path from .links.filters import filter_links_to_path @@ -122,11 +121,11 @@ def __init__( self.selection = selection # segment members are identified by storing nodes along a route - self._segment_nodes: Union[list, None] = None + self._segment_nodes: list | None = None # Initialize calculated, read-only attr. 
- self._from_node_id: Union[int, None] = None - self._to_node_id: Union[int, None] = None + self._from_node_id: int | None = None + self._to_node_id: int | None = None self.subnet = self._generate_subnet(self.segment_sel_dict) @@ -396,7 +395,7 @@ def identify_segment_endpoints( _links_df, ) ) - from .network import add_incident_link_data_to_nodes # noqa: PLC0415 + from .network import add_incident_link_data_to_nodes _nodes_df = add_incident_link_data_to_nodes( links_df=_links_df, diff --git a/network_wrangler/roadway/selection.py b/network_wrangler/roadway/selection.py index 239e26e4..9b7359bb 100644 --- a/network_wrangler/roadway/selection.py +++ b/network_wrangler/roadway/selection.py @@ -2,12 +2,9 @@ from __future__ import annotations -import copy import hashlib from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, ClassVar, Literal, Union - -import pandas as pd +from typing import TYPE_CHECKING, ClassVar, Literal from ..errors import SelectionError from ..logger import WranglerLogger @@ -70,7 +67,7 @@ class RoadwaySelection(ABC): def __init__( self, net: RoadwayNetwork, - selection_data: Union[SelectFacility, dict], + selection_data: SelectFacility | dict, ): """Constructor for RoadwaySelection object. 
@@ -129,7 +126,7 @@ def selection_dict(self) -> dict: return self._selection_dict @selection_dict.setter - def selection_dict(self, selection_input: Union[SelectFacility, dict]): + def selection_dict(self, selection_input: SelectFacility | dict): if isinstance(selection_input, SelectLinksDict): selection_input = SelectFacility(links=selection_input) elif isinstance(selection_input, SelectNodesDict): @@ -146,7 +143,7 @@ def selection_dict(self, selection_input: Union[SelectFacility, dict]): self._selection_data = self.validate_selection(selection_input) self._selection_dict = self._selection_data.asdict - self._stored_net_hash = copy.deepcopy(self.net.network_hash) + self._stored_net_version = self.net.modification_version @property def node_query_fields(self) -> list[str]: @@ -214,7 +211,7 @@ class RoadwayLinkSelection(RoadwaySelection): def __init__( self, net: RoadwayNetwork, - selection_data: Union[SelectFacility, dict], + selection_data: SelectFacility | dict, ): """Constructor for RoadwayLinkSelection object. @@ -224,8 +221,8 @@ def __init__( `SelectFacility` model with a "links" key or SelectFacility instance. 
""" super().__init__(net, selection_data) - self._selected_links_df: Union[None, DataFrame[RoadLinksTable]] = None - self._segment: Union[None, Segment] = None + self._selected_links_df: None | DataFrame[RoadLinksTable] = None + self._segment: None | Segment = None WranglerLogger.debug(f"Created LinkSelection of type: {self.selection_method}") def __nonzero__(self) -> bool: @@ -323,7 +320,7 @@ def found(self) -> bool: return self.selected_links_df is not None @property - def segment(self) -> Union[None, Segment]: + def segment(self) -> None | Segment: """Return the segment object if selection type is segment.""" if self._segment is None and self.selection_method == "segment": WranglerLogger.debug("Creating new segment") @@ -339,11 +336,14 @@ def create_segment(self, max_search_breadth: int): def selected_links_df(self) -> DataFrame[RoadLinksTable]: """Lazily evaluates selection for links or returns stored value in self._selected_links_df. - Will re-evaluate if the current network hash is different than the stored one from the - last selection. + Will re-evaluate if the current network modification version is different than the stored + one from the last selection. """ - if self._selected_links_df is None or self._stored_net_hash != self.net.network_hash: - self._stored_net_hash = copy.deepcopy(self.net.network_hash) + if ( + self._selected_links_df is None + or self._stored_net_version != self.net.modification_version + ): + self._stored_net_version = self.net.modification_version self._selected_links_df = self._perform_selection() return self._selected_links_df @@ -436,7 +436,7 @@ class RoadwayNodeSelection(RoadwaySelection): def __init__( self, net: RoadwayNetwork, - selection_data: Union[dict, SelectFacility], + selection_data: dict | SelectFacility, ): """Constructor for RoadwayNodeSelection object. @@ -446,7 +446,7 @@ def __init__( conforming to SelectFacility format, or SelectFacility instance. 
""" super().__init__(net, selection_data) - self._selected_nodes_df: Union[None, DataFrame[RoadNodesTable]] = None + self._selected_nodes_df: None | DataFrame[RoadNodesTable] = None def __nonzero__(self) -> bool: """Return True if nodes were selected.""" @@ -539,11 +539,14 @@ def found(self) -> bool: def selected_nodes_df(self) -> DataFrame[RoadNodesTable]: """Lazily evaluates selection for nodes or returns stored value in self._selected_nodes_df. - Will re-evaluate if the current network hash is different than the stored one from the - last selection. + Will re-evaluate if the current network modification version is different than the stored + one from the last selection. """ - if self._selected_nodes_df is None or self._stored_net_hash != self.net.network_hash: - self._stored_net_hash = self.net.network_hash + if ( + self._selected_nodes_df is None + or self._stored_net_version != self.net.modification_version + ): + self._stored_net_version = self.net.modification_version self._selected_nodes_df = self._perform_selection() return self._selected_nodes_df @@ -599,7 +602,7 @@ def _perform_selection(self): def _create_selection_key( - selection_dict: Union[SelectLinksDict, SelectNodesDict, SelectFacility, dict], + selection_dict: SelectLinksDict | SelectNodesDict | SelectFacility | dict, ) -> str: """Selections are stored by a sha1 hash of the bit-encoded string of the selection dictionary. 
diff --git a/network_wrangler/roadway/shapes/io.py b/network_wrangler/roadway/shapes/io.py index df71947a..2f56febc 100644 --- a/network_wrangler/roadway/shapes/io.py +++ b/network_wrangler/roadway/shapes/io.py @@ -4,7 +4,6 @@ import time from pathlib import Path -from typing import Optional, Union from geopandas import GeoDataFrame from pandera.typing import DataFrame @@ -27,10 +26,10 @@ def read_shapes( filename: Path, in_crs: int = LAT_LON_CRS, - boundary_gdf: Optional[GeoDataFrame] = None, - boundary_geocode: Optional[str] = None, - boundary_file: Optional[Path] = None, - filter_to_shape_ids: Optional[list] = None, + boundary_gdf: GeoDataFrame | None = None, + boundary_geocode: str | None = None, + boundary_file: Path | None = None, + filter_to_shape_ids: list | None = None, config: WranglerConfig = DefaultConfig, ) -> DataFrame[RoadShapesTable]: """Reads shapes and returns a geodataframe of shapes if filename is found. @@ -88,7 +87,7 @@ def read_shapes( @validate_call_pyd def write_shapes( shapes_df: DataFrame[RoadShapesTable], - out_dir: Union[str, Path], + out_dir: str | Path, prefix: str, format: str, overwrite: bool, diff --git a/network_wrangler/roadway/shapes/validate.py b/network_wrangler/roadway/shapes/validate.py index b1db1b0f..a060b4b6 100644 --- a/network_wrangler/roadway/shapes/validate.py +++ b/network_wrangler/roadway/shapes/validate.py @@ -5,7 +5,6 @@ from pathlib import Path import pandas as pd -import pyarrow as pa from ...logger import WranglerLogger from ...models.roadway.tables import RoadShapesTable @@ -47,7 +46,7 @@ def validate_shapes_df( is_valid = True if not strict: - from .create import df_to_shapes_df # noqa: PLC0415 + from .create import df_to_shapes_df try: shapes_df = df_to_shapes_df(shapes_df) diff --git a/network_wrangler/roadway/subnet.py b/network_wrangler/roadway/subnet.py index 36882e29..22d6ea1b 100644 --- a/network_wrangler/roadway/subnet.py +++ b/network_wrangler/roadway/subnet.py @@ -3,7 +3,7 @@ from __future__ import 
annotations import hashlib -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pandas as pd from pandera.typing import DataFrame @@ -70,7 +70,7 @@ class Subnet: def __init__( self, net: RoadwayNetwork, - modes: Optional[list] = DEFAULT_SEARCH_MODES, + modes: list | None = DEFAULT_SEARCH_MODES, subnet_links_df: pd.DataFrame = None, i: int = 0, sp_weight_factor: float = DEFAULT_SUBNET_SP_WEIGHT_FACTOR, diff --git a/network_wrangler/roadway/utils.py b/network_wrangler/roadway/utils.py index 1d42c9fb..5a2b9950 100644 --- a/network_wrangler/roadway/utils.py +++ b/network_wrangler/roadway/utils.py @@ -3,7 +3,7 @@ from __future__ import annotations import hashlib -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import pandas as pd @@ -18,8 +18,8 @@ def compare_networks( - nets: list[Union[RoadwayNetwork, ModelRoadwayNetwork]], - names: Optional[list[str]] = None, + nets: list[RoadwayNetwork | ModelRoadwayNetwork], + names: list[str] | None = None, ) -> pd.DataFrame: """Compare the summary of networks in a list of networks. @@ -29,13 +29,13 @@ def compare_networks( """ if names is None: names = ["net" + str(i) for i in range(1, len(nets) + 1)] - df = pd.DataFrame({name: net.summary for name, net in zip(names, nets)}) + df = pd.DataFrame({name: net.summary for name, net in zip(names, nets, strict=True)}) return df def compare_links( links: list[pd.DataFrame], - names: Optional[list[str]] = None, + names: list[str] | None = None, ) -> pd.DataFrame: """Compare the summary of links in a list of dataframes. 
@@ -45,7 +45,9 @@ def compare_links( """ if names is None: names = ["links" + str(i) for i in range(1, len(links) + 1)] - df = pd.DataFrame({name: link.of_type.summary for name, link in zip(names, links)}) + df = pd.DataFrame( + {name: link.of_type.summary for name, link in zip(names, links, strict=True)} + ) return df diff --git a/network_wrangler/roadway/validate.py b/network_wrangler/roadway/validate.py index 9dec8ecb..89c31640 100644 --- a/network_wrangler/roadway/validate.py +++ b/network_wrangler/roadway/validate.py @@ -1,7 +1,6 @@ """Validates a roadway network to the wrangler data model specifications.""" from pathlib import Path -from typing import Optional from ..logger import WranglerLogger from ..models._base.types import RoadwayFileTypes @@ -37,7 +36,7 @@ def validate_roadway_in_dir( def validate_roadway_files( links_file: Path, nodes_file: Path, - shapes_file: Optional[Path] = None, + shapes_file: Path | None = None, strict: bool = False, output_dir: Path = Path(), ): diff --git a/network_wrangler/scenario.py b/network_wrangler/scenario.py index 819d8294..536c760e 100644 --- a/network_wrangler/scenario.py +++ b/network_wrangler/scenario.py @@ -159,7 +159,7 @@ from collections import defaultdict, deque from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import yaml from projectcard import ProjectCard, SubProject, read_cards, write_card @@ -291,9 +291,9 @@ class Scenario: def __init__( self, - base_scenario: Union[Scenario, dict], - project_card_list: Optional[list[ProjectCard]] = None, - config: Optional[Union[WranglerConfig, dict, Path, list[Path]]] = None, + base_scenario: Scenario | dict, + project_card_list: list[ProjectCard] | None = None, + config: WranglerConfig | dict | Path | list[Path] | None = None, name: str = "", ): """Constructor. 
@@ -327,11 +327,9 @@ def __init__( ) self.name: str = name # if the base scenario had roadway or transit networks, use them as the basis. - self.road_net: Optional[RoadwayNetwork] = copy.deepcopy( - base_scenario.pop("road_net", None) - ) + self.road_net: RoadwayNetwork | None = copy.deepcopy(base_scenario.pop("road_net", None)) - self.transit_net: Optional[TransitNetwork] = copy.deepcopy( + self.transit_net: TransitNetwork | None = copy.deepcopy( base_scenario.pop("transit_net", None) ) if self.road_net and self.transit_net: @@ -402,7 +400,7 @@ def _add_project( self, project_card: ProjectCard, validate: bool = True, - filter_tags: Optional[list[str]] = None, + filter_tags: list[str] | None = None, ) -> None: """Adds a single ProjectCard instances to the Scenario. @@ -448,7 +446,7 @@ def add_project_cards( self, project_card_list: list[ProjectCard], validate: bool = True, - filter_tags: Optional[list[str]] = None, + filter_tags: list[str] | None = None, ) -> None: """Adds a list of ProjectCard instances to the Scenario. @@ -620,7 +618,7 @@ def apply_all_projects(self): # set this so it will trigger re-queuing any more projects. self._queued_projects = None - def _apply_change(self, change: Union[ProjectCard, SubProject]) -> None: + def _apply_change(self, change: ProjectCard | SubProject) -> None: """Applies a specific change specified in a project card. 
Change type must be in at least one of: @@ -708,14 +706,14 @@ def write( transit_write: bool = True, projects_write: bool = True, roadway_convert_complex_link_properties_to_single_field: bool = False, - roadway_out_dir: Optional[Path] = None, - roadway_prefix: Optional[str] = None, + roadway_out_dir: Path | None = None, + roadway_prefix: str | None = None, roadway_file_format: RoadwayFileTypes = "parquet", roadway_true_shape: bool = False, - transit_out_dir: Optional[Path] = None, - transit_prefix: Optional[str] = None, + transit_out_dir: Path | None = None, + transit_prefix: str | None = None, transit_file_format: TransitFileTypes = "txt", - projects_out_dir: Optional[Path] = None, + projects_out_dir: Path | None = None, ) -> Path: """Writes scenario networks and summary to disk and returns path to scenario file. @@ -817,12 +815,12 @@ def summary(self) -> dict: def create_scenario( - base_scenario: Optional[Union[Scenario, dict]] = None, + base_scenario: Scenario | dict | None = None, name: str = datetime.now().strftime("%Y%m%d%H%M%S"), project_card_list=None, - project_card_filepath: Optional[Union[list[Path], Path]] = None, - filter_tags: Optional[list[str]] = None, - config: Optional[Union[dict, Path, list[Path], WranglerConfig]] = None, + project_card_filepath: list[Path] | Path | None = None, + filter_tags: list[str] | None = None, + config: dict | Path | list[Path] | WranglerConfig | None = None, ) -> Scenario: """Creates scenario from a base scenario and adds project cards. @@ -889,7 +887,7 @@ def write_applied_projects(scenario: Scenario, out_dir: Path, overwrite: bool = def load_scenario( - scenario_data: Union[dict, Path], + scenario_data: dict | Path, name: str = datetime.now().strftime("%Y%m%d%H%M%S"), ) -> Scenario: """Loads a scenario from a file written by Scenario.write() as the base scenario. 
@@ -920,10 +918,10 @@ def load_scenario( def create_base_scenario( - roadway: Optional[dict] = None, - transit: Optional[dict] = None, - applied_projects: Optional[list] = None, - conflicts: Optional[dict] = None, + roadway: dict | None = None, + transit: dict | None = None, + applied_projects: list | None = None, + conflicts: dict | None = None, config: WranglerConfig = DefaultConfig, ) -> dict: """Creates a base scenario dictionary from roadway and transit network files. @@ -969,7 +967,7 @@ def create_base_scenario( def _load_base_scenario_from_config( - base_scenario_data: Union[dict, ScenarioInputConfig], config: WranglerConfig = DefaultConfig + base_scenario_data: dict | ScenarioInputConfig, config: WranglerConfig = DefaultConfig ) -> dict: """Loads a scenario from a file written by Scenario.write() as the base scenario. @@ -1008,7 +1006,7 @@ def extract_base_scenario_metadata(base_scenario: dict) -> dict: def build_scenario_from_config( - scenario_config: Union[Path, list[Path], ScenarioConfig, dict], + scenario_config: Path | list[Path] | ScenarioConfig | dict, ) -> Scenario: """Builds a scenario from a dictionary configuration. 
diff --git a/network_wrangler/transit/clip.py b/network_wrangler/transit/clip.py index 6ccd685e..90ddf2e0 100644 --- a/network_wrangler/transit/clip.py +++ b/network_wrangler/transit/clip.py @@ -19,7 +19,6 @@ from __future__ import annotations from pathlib import Path -from typing import Optional, Union import geopandas as gpd import pandas as pd @@ -134,9 +133,9 @@ def _remove_links_from_feed( def clip_feed_to_boundary( feed: Feed, ref_nodes_df: gpd.GeoDataFrame, - boundary_gdf: Optional[gpd.GeoDataFrame] = None, - boundary_geocode: Optional[Union[str, dict]] = None, - boundary_file: Optional[Union[str, Path]] = None, + boundary_gdf: gpd.GeoDataFrame | None = None, + boundary_geocode: str | dict | None = None, + boundary_file: str | Path | None = None, min_stops: int = DEFAULT_MIN_STOPS, ) -> Feed: """Clips a transit Feed object to a boundary and returns the resulting GeoDataFrames. @@ -234,13 +233,13 @@ def _clip_feed_to_nodes( def clip_transit( - network: Union[TransitNetwork, str, Path], - node_ids: Optional[Union[None, list[str]]] = None, - boundary_geocode: Optional[Union[str, dict, None]] = None, - boundary_file: Optional[Union[str, Path]] = None, - boundary_gdf: Optional[Union[None, gpd.GeoDataFrame]] = None, - ref_nodes_df: Optional[Union[None, gpd.GeoDataFrame]] = None, - roadway_net: Optional[Union[None, RoadwayNetwork]] = None, + network: TransitNetwork | str | Path, + node_ids: None | list[str] = None, + boundary_geocode: None | str | dict = None, + boundary_file: str | Path | None = None, + boundary_gdf: None | gpd.GeoDataFrame = None, + ref_nodes_df: None | gpd.GeoDataFrame = None, + roadway_net: None | RoadwayNetwork = None, min_stops: int = DEFAULT_MIN_STOPS, ) -> TransitNetwork: """Returns a new TransitNetwork clipped to a boundary as determined by arguments. 
diff --git a/network_wrangler/transit/feed/feed.py b/network_wrangler/transit/feed/feed.py index 4ae5a0b8..da99083a 100644 --- a/network_wrangler/transit/feed/feed.py +++ b/network_wrangler/transit/feed/feed.py @@ -2,8 +2,9 @@ from __future__ import annotations +from collections.abc import Callable from pathlib import Path -from typing import Callable, ClassVar, Literal, Optional +from typing import ClassVar, Literal import pandas as pd from pandera.typing import DataFrame @@ -87,7 +88,7 @@ def __init__(self, **kwargs): extra_attr = {k: v for k, v in kwargs.items() if k not in self.table_names} if extra_attr: WranglerLogger.info(f"Adding additional attributes to Feed: {extra_attr.keys()}") - for k, v in extra_attr: + for k, v in extra_attr.items(): self.__setattr__(k, v) def set_by_id( @@ -95,7 +96,7 @@ def set_by_id( table_name: str, set_df: pd.DataFrame, id_property: str = "index", - properties: Optional[list[str]] = None, + properties: list[str] | None = None, ): """Set one or more property values based on an ID property for a given table. 
diff --git a/network_wrangler/transit/feed/shapes.py b/network_wrangler/transit/feed/shapes.py index 2bdb318d..d9e30871 100644 --- a/network_wrangler/transit/feed/shapes.py +++ b/network_wrangler/transit/feed/shapes.py @@ -2,10 +2,7 @@ from __future__ import annotations -import ijson import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq from pandera.typing import DataFrame from ...logger import WranglerLogger @@ -75,7 +72,7 @@ def shapes_with_stop_id_for_trip_id( "pickup_only": only pickup > 0 "dropoff_only": only dropoff > 0 """ - from .stop_times import stop_times_for_pickup_dropoff_trip_id # noqa: PLC0415 + from .stop_times import stop_times_for_pickup_dropoff_trip_id shapes = shapes_for_trip_id(shapes, trips, trip_id) trip_stop_times = stop_times_for_pickup_dropoff_trip_id( @@ -123,7 +120,7 @@ def shapes_with_stops_for_shape_id( Returns: DataFrame[WranglerShapesTable]: DataFrame containing shapes with associated stops. """ - from .trips import trip_ids_for_shape_id # noqa: PLC0415 + from .trips import trip_ids_for_shape_id trip_ids = trip_ids_for_shape_id(trips, shape_id) all_shape_stop_times = concat_with_attr( diff --git a/network_wrangler/transit/feed/stops.py b/network_wrangler/transit/feed/stops.py index 43d2ee0a..226e234e 100644 --- a/network_wrangler/transit/feed/stops.py +++ b/network_wrangler/transit/feed/stops.py @@ -2,9 +2,6 @@ from __future__ import annotations -from typing import Union - -import pyarrow as pa from pandera.typing import DataFrame from ...logger import WranglerLogger @@ -31,7 +28,7 @@ def stop_id_pattern_for_trip( "pickup_only": only pickup > 0 "dropoff_only": only dropoff > 0 """ - from .stop_times import stop_times_for_pickup_dropoff_trip_id # noqa: PLC0415 + from .stop_times import stop_times_for_pickup_dropoff_trip_id trip_stops = stop_times_for_pickup_dropoff_trip_id( stop_times, trip_id, pickup_dropoff=pickup_dropoff @@ -66,10 +63,10 @@ def stops_for_trip_id( def node_is_stop( stops: 
DataFrame[WranglerStopsTable], stop_times: DataFrame[WranglerStopTimesTable], - node_id: Union[int, list[int]], + node_id: int | list[int], trip_id: str, pickup_dropoff: PickupDropoffAvailability = "either", -) -> Union[bool, list[bool]]: +) -> bool | list[bool]: """Returns boolean indicating if a (or list of) node(s)) is (are) stops for a given trip_id. Args: diff --git a/network_wrangler/transit/feed/transit_links.py b/network_wrangler/transit/feed/transit_links.py index bd51f325..7a377570 100644 --- a/network_wrangler/transit/feed/transit_links.py +++ b/network_wrangler/transit/feed/transit_links.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Union - import pandas as pd from pandera.typing import DataFrame @@ -49,7 +47,7 @@ def unique_shape_links( shape_links = shapes_to_shape_links(shapes) # WranglerLogger.debug(f"Shape links: \n {shape_links[['shape_id', from_field, to_field]]}") - _agg_dict: dict[str, Union[type, str]] = {"shape_id": list} + _agg_dict: dict[str, type | str] = {"shape_id": list} _opt_fields = [f"shape_pt_{v}_{t}" for v in ["lat", "lon"] for t in [from_field, to_field]] for f in _opt_fields: if f in shape_links: diff --git a/network_wrangler/transit/geo.py b/network_wrangler/transit/geo.py index abaede95..7ebbed30 100644 --- a/network_wrangler/transit/geo.py +++ b/network_wrangler/transit/geo.py @@ -2,10 +2,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import geopandas as gpd -import pyarrow as pa from pandera.typing import DataFrame from shapely import LineString @@ -25,7 +24,7 @@ def shapes_to_trip_shapes_gdf( shapes: DataFrame[WranglerShapesTable], # trips: WranglerTripsTable, - ref_nodes_df: Optional[DataFrame[RoadNodesTable]] = None, + ref_nodes_df: DataFrame[RoadNodesTable] | None = None, crs: int = LAT_LON_CRS, ) -> gpd.GeoDataFrame: """Geodataframe with one polyline shape per shape_id. 
@@ -46,7 +45,7 @@ def shapes_to_trip_shapes_gdf( shapes[["shape_id", "shape_pt_lat", "shape_pt_lon"]] .groupby("shape_id") .agg(list) - .apply(lambda x: LineString(zip(x[1], x[0])), axis=1) + .apply(lambda x: LineString(zip(x[1], x[0], strict=True)), axis=1) ) route_shapes_gdf = gpd.GeoDataFrame( @@ -86,7 +85,7 @@ def update_shapes_geometry( def shapes_to_shape_links_gdf( shapes: DataFrame[WranglerShapesTable], - ref_nodes_df: Optional[DataFrame[RoadNodesTable]] = None, + ref_nodes_df: DataFrame[RoadNodesTable] | None = None, from_field: str = "A", to_field: str = "B", crs: int = LAT_LON_CRS, @@ -125,7 +124,7 @@ def shapes_to_shape_links_gdf( def stop_times_to_stop_time_points_gdf( stop_times: DataFrame[WranglerStopTimesTable], stops: DataFrame[WranglerStopsTable], - ref_nodes_df: Optional[DataFrame[RoadNodesTable]] = None, + ref_nodes_df: DataFrame[RoadNodesTable] | None = None, ) -> gpd.GeoDataFrame: """Stoptimes geodataframe as points using geometry from stops.txt or optionally another df. @@ -154,7 +153,7 @@ def stop_times_to_stop_time_points_gdf( def stop_times_to_stop_time_links_gdf( stop_times: DataFrame[WranglerStopTimesTable], stops: DataFrame[WranglerStopsTable], - ref_nodes_df: Optional[DataFrame[RoadNodesTable]] = None, + ref_nodes_df: DataFrame[RoadNodesTable] | None = None, from_field: str = "A", to_field: str = "B", ) -> gpd.GeoDataFrame: @@ -168,7 +167,7 @@ def stop_times_to_stop_time_links_gdf( from_field: Field used for the link's from node `model_node_id`. Defaults to "A". to_field: Field used for the link's to node `model_node_id`. Defaults to "B". 
""" - from ..utils.geo import linestring_from_lats_lons # noqa: PLC0415 + from ..utils.geo import linestring_from_lats_lons if ref_nodes_df is not None: stops = update_stops_geometry(stops, ref_nodes_df) diff --git a/network_wrangler/transit/io.py b/network_wrangler/transit/io.py index 88c960ad..eed70550 100644 --- a/network_wrangler/transit/io.py +++ b/network_wrangler/transit/io.py @@ -1,11 +1,10 @@ """Functions for reading and writing transit feeds and networks.""" from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal import geopandas as gpd import pandas as pd -import pyarrow as pa from ..configs import DefaultConfig, WranglerConfig from ..errors import FeedReadError @@ -29,9 +28,7 @@ def _feed_path_ref(path: Path) -> Path: return path -def load_feed_from_path( - feed_path: Union[Path, str], file_format: TransitFileTypes = "txt" -) -> Feed: +def load_feed_from_path(feed_path: Path | str, file_format: TransitFileTypes = "txt") -> Feed: """Create a Feed object from the path to a GTFS transit feed. 
Args: @@ -120,7 +117,7 @@ def load_feed_from_dfs(feed_dfs: dict) -> Feed: def load_transit( - feed: Union[Feed, GtfsModel, dict[str, pd.DataFrame], str, Path], + feed: Feed | GtfsModel | dict[str, pd.DataFrame] | str | Path, file_format: TransitFileTypes = "txt", config: WranglerConfig = DefaultConfig, ) -> "TransitNetwork": @@ -155,7 +152,7 @@ def load_transit( ``` """ - if isinstance(feed, (Path, str)): + if isinstance(feed, Path | str): feed = Path(feed) feed_obj = load_feed_from_path(feed, file_format=file_format) feed_obj.feed_path = feed @@ -174,8 +171,8 @@ def load_transit( def write_transit( transit_net, - out_dir: Union[Path, str] = ".", - prefix: Optional[Union[Path, str]] = None, + out_dir: Path | str = ".", + prefix: Path | str | None = None, file_format: Literal["txt", "csv", "parquet"] = "txt", overwrite: bool = True, ) -> None: @@ -199,9 +196,9 @@ def write_transit( def convert_transit_serialization( - input_path: Union[str, Path], + input_path: str | Path, output_format: TransitFileTypes, - out_dir: Union[Path, str] = ".", + out_dir: Path | str = ".", input_file_format: TransitFileTypes = "csv", out_prefix: str = "", overwrite: bool = True, @@ -234,7 +231,7 @@ def convert_transit_serialization( def write_feed_geo( feed: Feed, ref_nodes_df: gpd.GeoDataFrame, - out_dir: Union[str, Path], + out_dir: str | Path, file_format: Literal["geojson", "shp", "parquet"] = "geojson", out_prefix=None, overwrite: bool = True, @@ -249,7 +246,7 @@ def write_feed_geo( out_prefix: prefix to add to the file name overwrite: if True, will overwrite the files if they already exist. 
Defaults to True """ - from .geo import shapes_to_shape_links_gdf # noqa: PLC0415 + from .geo import shapes_to_shape_links_gdf out_dir = Path(out_dir) if not out_dir.is_dir(): diff --git a/network_wrangler/transit/model_transit.py b/network_wrangler/transit/model_transit.py index bebd0bed..17c1b182 100644 --- a/network_wrangler/transit/model_transit.py +++ b/network_wrangler/transit/model_transit.py @@ -26,8 +26,8 @@ def __init__( """ModelTransit class for managing consistency between roadway and transit networks.""" self.transit_net = transit_net self.roadway_net = roadway_net - self._roadway_net_hash = None - self._transit_feed_hash = None + self._roadway_net_version = None + self._transit_feed_version = None self._transit_shifted_to_ML = shift_transit_to_managed_lanes @property @@ -39,8 +39,8 @@ def model_roadway_net(self): def consistent_nets(self) -> bool: """Indicate if roadway and transit networks have changed since self.m_feed updated.""" return bool( - self.roadway_net.network_hash == self._roadway_net_hash - and self.transit_net.feed_hash == self._transit_feed_hash + self.roadway_net.modification_version == self._roadway_net_version + and self.transit_net.feed.modification_version == self._transit_feed_version ) @property @@ -49,9 +49,9 @@ def m_feed(self): if self.consistent_nets: return self._m_feed # NOTE: look at this - # If netoworks have changed, updated model transit and update reference hash - self._roadway_net_hash = copy.deepcopy(self.roadway_net.network_hash) - self._transit_feed_hash = copy.deepcopy(self.transit_net.feed_hash) + # If networks have changed, update model transit and update reference version + self._roadway_net_version = self.roadway_net.modification_version + self._transit_feed_version = self.transit_net.feed.modification_version if not self._transit_shifted_to_ML: self._m_feed = copy.deepcopy(self.transit_net.feed) diff --git a/network_wrangler/transit/network.py b/network_wrangler/transit/network.py index cd2f52c6..4d7eb4bb 
100644 --- a/network_wrangler/transit/network.py +++ b/network_wrangler/transit/network.py @@ -19,7 +19,7 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, ClassVar, Optional, Union +from typing import TYPE_CHECKING, ClassVar import geopandas as gpd import networkx as nx @@ -88,7 +88,7 @@ def __init__(self, feed: Feed, config: WranglerConfig = DefaultConfig) -> None: """ WranglerLogger.debug("Creating new TransitNetwork.") - self._road_net: Optional[RoadwayNetwork] = None + self._road_net: RoadwayNetwork | None = None self.feed: Feed = feed self.graph: nx.MultiDiGraph = None self.config: WranglerConfig = config @@ -126,14 +126,14 @@ def feed(self, feed: Feed): raise TransitValidationError(msg) if self._road_net is None or transit_road_net_consistency(feed, self._road_net): self._feed = feed - self._stored_feed_hash = copy.deepcopy(feed.hash) + self._stored_feed_version = feed.modification_version else: msg = "Can't assign Feed inconsistent with set Roadway Network." WranglerLogger.error(msg) raise TransitRoadwayConsistencyError(msg) @property - def road_net(self) -> Union[None, RoadwayNetwork]: + def road_net(self) -> None | RoadwayNetwork: """Roadway network associated with the transit network.""" return self._road_net @@ -146,7 +146,7 @@ def road_net(self, road_net_in: RoadwayNetwork): raise TransitValidationError(msg) if transit_road_net_consistency(self.feed, road_net_in): self._road_net = road_net_in - self._stored_road_net_hash = copy.deepcopy(road_net_in.network_hash) + self._stored_road_net_version = road_net_in.modification_version self._consistent_with_road_net = True else: msg = "Can't assign inconsistent RoadwayNetwork - Roadway Network not \ @@ -165,9 +165,8 @@ def consistent_with_road_net(self) -> bool: Will return True if road_net is None, but provide a warning. - Checks the network hash of when consistency was last evaluated. 
If transit network or - roadway network has changed, will re-evaluate consistency and return the updated value and - update self._stored_road_net_hash. + Checks the modification version of when consistency was last evaluated. If transit network + or roadway network has changed, will re-evaluate consistency and return the updated value. Returns: Boolean indicating if road_net is consistent with transit network. @@ -175,13 +174,13 @@ def consistent_with_road_net(self) -> bool: if self.road_net is None: WranglerLogger.warning("Roadway Network not set, cannot accurately check consistency.") return True - updated_road = self.road_net.network_hash != self._stored_road_net_hash - updated_feed = self.feed_hash != self._stored_feed_hash + updated_road = self.road_net.modification_version != self._stored_road_net_version + updated_feed = self.feed.modification_version != self._stored_feed_version if updated_road or updated_feed: self._consistent_with_road_net = transit_road_net_consistency(self.feed, self.road_net) - self._stored_road_net_hash = copy.deepcopy(self.road_net.network_hash) - self._stored_feed_hash = copy.deepcopy(self.feed_hash) + self._stored_road_net_version = self.road_net.modification_version + self._stored_feed_version = self.feed.modification_version return self._consistent_with_road_net def __deepcopy__(self, memo): @@ -274,14 +273,14 @@ def get_selection( raise TransitSelectionEmptyError(msg) return self._selections[key] - def apply(self, project_card: Union[ProjectCard, dict], **kwargs) -> TransitNetwork: + def apply(self, project_card: ProjectCard | dict, **kwargs) -> TransitNetwork: """Wrapper method to apply a roadway project, returning a new TransitNetwork instance. 
Args: project_card: either a dictionary of the project card object or ProjectCard instance **kwargs: keyword arguments to pass to project application """ - if not (isinstance(project_card, (ProjectCard, SubProject))): + if not (isinstance(project_card, ProjectCard | SubProject)): project_card = ProjectCard(project_card) if not project_card.valid: @@ -298,8 +297,8 @@ def apply(self, project_card: Union[ProjectCard, dict], **kwargs) -> TransitNetw def _apply_change( self, - change: Union[ProjectCard, SubProject], - reference_road_net: Optional[RoadwayNetwork] = None, + change: ProjectCard | SubProject, + reference_road_net: RoadwayNetwork | None = None, ) -> TransitNetwork: """Apply a single change: a single-project project or a sub-project.""" if not isinstance(change, SubProject): diff --git a/network_wrangler/transit/projects/add_route.py b/network_wrangler/transit/projects/add_route.py index 0e76df78..884c21d0 100644 --- a/network_wrangler/transit/projects/add_route.py +++ b/network_wrangler/transit/projects/add_route.py @@ -4,7 +4,7 @@ import copy from datetime import datetime -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import pandas as pd from pandera.typing import DataFrame as paDataFrame @@ -21,7 +21,6 @@ ) from ...utils.data import concat_with_attr from ...utils.ids import create_str_int_combo_ids -from ...utils.models import fill_df_with_defaults_from_model from ...utils.time import str_to_time_list if TYPE_CHECKING: @@ -33,7 +32,7 @@ def apply_transit_route_addition( net: TransitNetwork, transit_route_addition: dict, - reference_road_net: Optional[RoadwayNetwork] = None, + reference_road_net: RoadwayNetwork | None = None, ) -> TransitNetwork: """Add transit route to TransitNetwork. 
@@ -99,7 +98,7 @@ def _add_route_to_feed( WranglerLogger.debug(f"Adding {len(route['trips'])} trips for route {route['route_id']}.") shape_ids = create_str_int_combo_ids(len(route["trips"]), shapes_df["shape_id"]) - for trip, shape_id in zip(route["trips"], shape_ids): + for trip, shape_id in zip(route["trips"], shape_ids, strict=True): add_shape_df = _create_new_shape(trip["routing"], shape_id, road_net) shapes_df = concat_with_attr([shapes_df, add_shape_df], ignore_index=True, sort=False) @@ -164,7 +163,7 @@ def _create_new_trips( def _create_new_shape( - routing: list[Union[dict, int]], shape_id: str, road_net: RoadwayNetwork + routing: list[dict | int], shape_id: str, road_net: RoadwayNetwork ) -> paDataFrame[WranglerShapesTable]: """Create new shape for a trip. @@ -178,7 +177,7 @@ def _create_new_shape( int(next(iter(item.keys()))) if isinstance(item, dict) else int(item) for item in routing ] coords = [road_net.node_coords(n) for n in shape_model_node_id_list] - lon, lat = zip(*coords) + lon, lat = zip(*coords, strict=True) add_shapes_df = pd.DataFrame( { "shape_model_node_id": shape_model_node_id_list, @@ -191,7 +190,7 @@ def _create_new_shape( return add_shapes_df -def _get_stops_from_routing(routing: list[Union[dict, int]]) -> list[dict]: +def _get_stops_from_routing(routing: list[dict | int]) -> list[dict]: """Converts a routing list to stop_id_list, drop_off_type, and pickup_type. Default for board and alight is True unless specified to be False. @@ -230,7 +229,7 @@ def _get_stops_from_routing(routing: list[Union[dict, int]]) -> list[dict]: def _create_new_stop_times( - trip_routing: list[Union[dict, int]], trip_id: str + trip_routing: list[dict | int], trip_id: str ) -> paDataFrame[WranglerStopTimesTable]: """Create new stop times for a trip. 
@@ -274,7 +273,7 @@ def _create_new_stops( add_stops_df = pd.DataFrame(columns=["stop_id", "stop_lat", "stop_lon"]) if add_stop_ids.size: coords = [road_net.node_coords(n) for n in add_stop_ids] - lon, lat = zip(*coords) + lon, lat = zip(*coords, strict=True) add_stops_df = pd.DataFrame({"stop_id": add_stop_ids, "stop_lat": lat, "stop_lon": lon}) return add_stops_df diff --git a/network_wrangler/transit/projects/delete_service.py b/network_wrangler/transit/projects/delete_service.py index 7b609f4a..8351f06f 100644 --- a/network_wrangler/transit/projects/delete_service.py +++ b/network_wrangler/transit/projects/delete_service.py @@ -2,8 +2,7 @@ from __future__ import annotations -import copy -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from ...logger import WranglerLogger from ..feed.routes import route_ids_for_trip_ids @@ -18,8 +17,8 @@ def apply_transit_service_deletion( net: TransitNetwork, selection: TransitSelection, - clean_shapes: Optional[bool] = False, - clean_routes: Optional[bool] = False, + clean_shapes: bool | None = False, + clean_routes: bool | None = False, ) -> TransitNetwork: """Delete transit service to TransitNetwork. @@ -47,8 +46,8 @@ def apply_transit_service_deletion( def _delete_trips_from_feed( feed: Feed, trip_ids: list, - clean_shapes: Optional[bool] = False, - clean_routes: Optional[bool] = False, + clean_shapes: bool | None = False, + clean_routes: bool | None = False, ) -> Feed: """Delete transit service from feed based on trip_ids. 
diff --git a/network_wrangler/transit/projects/edit_property.py b/network_wrangler/transit/projects/edit_property.py index 14667d1b..3b7b1265 100644 --- a/network_wrangler/transit/projects/edit_property.py +++ b/network_wrangler/transit/projects/edit_property.py @@ -3,7 +3,7 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from ...errors import ProjectCardError, TransitPropertyChangeError from ...logger import WranglerLogger @@ -24,7 +24,7 @@ def apply_transit_property_change( net: TransitNetwork, selection: TransitSelection, property_changes: dict, - project_name: Optional[str] = None, + project_name: str | None = None, ) -> TransitNetwork: """Apply changes to transit properties. @@ -68,7 +68,7 @@ def _apply_transit_property_change_to_table( selection: TransitSelection, prop_name: str, prop_change: dict, - project_name: Optional[str] = None, + project_name: str | None = None, ) -> TransitNetwork: table_name = _get_table_name_for_property(net, prop_name) WranglerLogger.debug(f"...modifying {prop_name} in {table_name}.") diff --git a/network_wrangler/transit/projects/edit_routing.py b/network_wrangler/transit/projects/edit_routing.py index 4d96f130..7660cbcf 100644 --- a/network_wrangler/transit/projects/edit_routing.py +++ b/network_wrangler/transit/projects/edit_routing.py @@ -3,7 +3,7 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -19,7 +19,7 @@ WranglerTripsTable, ) from ...utils.data import concat_with_attr, segment_data_by_selection_min_overlap -from ...utils.ids import generate_list_of_new_ids_from_existing, generate_new_id_from_existing +from ...utils.ids import generate_new_id_from_existing from ...utils.models import validate_df_to_model from ..feed.shapes import ( find_nearest_stops, @@ -42,7 +42,7 @@ def _create_stop_times( - set_stops_node_ids: 
list[int], trip_id: str, project_name: Optional[str] = None + set_stops_node_ids: list[int], trip_id: str, project_name: str | None = None ) -> DataFrame[WranglerStopTimesTable]: """Modifies a list of nodes from project card routing key to a shape dataframe. @@ -71,7 +71,7 @@ def _create_shapes( nodes_list: list[int], shape_id: str, road_net: RoadwayNetwork, - project_name: Optional[str] = None, + project_name: str | None = None, ) -> DataFrame[WranglerShapesTable]: """Modifies a list of nodes from project card routing key to rows in a shapes.txt dataframe. @@ -123,7 +123,7 @@ def _add_new_shape_copy( trip_ids: list[str], feed: Feed, id_scalar: int = DefaultConfig.IDS.TRANSIT_SHAPE_ID_SCALAR, - project_name: Optional[str] = None, + project_name: str | None = None, ) -> tuple[DataFrame[WranglerShapesTable], DataFrame[WranglerTripsTable], str]: """Create an identical new shape_id from shape matching old_shape_id for the trip_ids. @@ -161,7 +161,7 @@ def _replace_shapes_segment( set_routing: list[int], feed: Feed, road_net: RoadwayNetwork, - project_name: Optional[str] = None, + project_name: str | None = None, ) -> DataFrame[WranglerShapesTable]: """Returns shapes with a replaced segment for a given shape_id. @@ -229,7 +229,7 @@ def _replace_stop_times_segment_for_trip( trip_id: str, set_stops_nodes: list[int], feed: Feed, - project_name: Optional[str] = None, + project_name: str | None = None, ) -> DataFrame[WranglerStopTimesTable]: """Replaces a segment of a specific set of stop_time records with the same shape_id. @@ -330,8 +330,8 @@ def _update_shapes_and_trips( routing_set: list[int], shape_id_scalar: int, road_net: RoadwayNetwork, - routing_existing: Optional[list[int]] = None, - project_name: Optional[str] = None, + routing_existing: list[int] | None = None, + project_name: str | None = None, ) -> tuple[DataFrame[WranglerShapesTable], DataFrame[WranglerTripsTable]]: """Update shapes and trips for transit routing change. 
@@ -395,7 +395,7 @@ def _update_stops( feed: Feed, routing_set: list[int], road_net: RoadwayNetwork, - project_name: Optional[str] = None, + project_name: str | None = None, ) -> DataFrame[WranglerStopsTable]: """Update stops for transit routing change. @@ -475,7 +475,7 @@ def _update_stop_times_for_trip( trip_id: str, routing_set: list[int], routing_existing: list[int], - project_name: Optional[str] = None, + project_name: str | None = None, ) -> DataFrame[WranglerStopTimesTable]: """Update stop_times for a specific trip with new stop_times. @@ -545,8 +545,8 @@ def apply_transit_routing_change( net: TransitNetwork, selection: TransitSelection, routing_change: dict, - reference_road_net: Optional[RoadwayNetwork] = None, - project_name: Optional[str] = None, + reference_road_net: RoadwayNetwork | None = None, + project_name: str | None = None, ) -> TransitNetwork: """Apply a routing change to the transit network, including stop updates. diff --git a/network_wrangler/transit/selection.py b/network_wrangler/transit/selection.py index 269037be..ddd4eee3 100644 --- a/network_wrangler/transit/selection.py +++ b/network_wrangler/transit/selection.py @@ -29,7 +29,7 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from pandera.typing import DataFrame @@ -76,7 +76,7 @@ class TransitSelection: def __init__( self, net: TransitNetwork, - selection_dict: Union[dict, SelectTransitTrips], + selection_dict: dict | SelectTransitTrips, ): """Constructor for TransitSelection object. 
@@ -90,7 +90,7 @@ def __init__( # Initialize self._selected_trips_df = None self.sel_key = dict_to_hexkey(selection_dict) - self._stored_feed_hash = copy.deepcopy(self.net.feed.hash) + self._stored_feed_version = self.net.feed.modification_version WranglerLogger.debug(f"...created TransitSelection object: {selection_dict}") @@ -104,10 +104,10 @@ def selection_dict(self): return self._selection_dict @selection_dict.setter - def selection_dict(self, value: Union[dict, SelectTransitTrips]): + def selection_dict(self, value: dict | SelectTransitTrips): self._selection_dict = self.validate_selection_dict(value) - def validate_selection_dict(self, selection_dict: Union[dict, SelectTransitTrips]) -> dict: + def validate_selection_dict(self, selection_dict: dict | SelectTransitTrips) -> dict: """Check that selection dictionary has valid and used properties consistent with network. Checks that selection_dict is a valid TransitSelectionDict: @@ -151,17 +151,19 @@ def selected_trips(self) -> list: def selected_trips_df(self) -> DataFrame[WranglerTripsTable]: """Lazily evaluates selection for trips or returns stored value in self._selected_trips_df. - Will re-evaluate if the current network hash is different than the stored one from the - last selection. + Will re-evaluate if the current feed modification version is different than the stored + one from the last selection. 
Returns: DataFrame[WranglerTripsTable] of selected trips """ - if (self._selected_trips_df is not None) and self._stored_feed_hash == self.net.feed_hash: + if ( + self._selected_trips_df is not None + ) and self._stored_feed_version == self.net.feed.modification_version: return self._selected_trips_df self._selected_trips_df = self._select_trips() - self._stored_feed_hash = copy.deepcopy(self.net.feed_hash) + self._stored_feed_version = self.net.feed.modification_version return self._selected_trips_df @property @@ -249,7 +251,7 @@ def _filter_trips_by_selection_dict( def _filter_trips_by_links( trips_df: DataFrame[WranglerTripsTable], shapes_df: DataFrame[WranglerShapesTable], # noqa: ARG001 - select_links: Union[SelectTransitLinks, None], + select_links: SelectTransitLinks | None, ) -> DataFrame[WranglerTripsTable]: if select_links is None: return trips_df diff --git a/network_wrangler/transit/validate.py b/network_wrangler/transit/validate.py index 7497d3b4..29982c28 100644 --- a/network_wrangler/transit/validate.py +++ b/network_wrangler/transit/validate.py @@ -3,12 +3,9 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING -import ijson import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq from pandera.errors import SchemaErrors from pandera.typing import DataFrame @@ -147,7 +144,7 @@ def transit_road_net_consistency(feed: Feed, road_net: RoadwayNetwork) -> bool: def validate_transit_in_dir( dir: Path, file_format: TransitFileTypes = "txt", - road_dir: Optional[Path] = None, + road_dir: Path | None = None, road_file_format: RoadwayFileTypes = "geojson", ) -> bool: """Validates a roadway network in a directory to the wrangler data model specifications. @@ -159,7 +156,7 @@ def validate_transit_in_dir( road_file_format (str): The format of roadway network file name. Defaults to "geojson". output_dir (str): The output directory for the validation report. 
Defaults to ".". """ - from .io import load_transit # noqa: PLC0415 + from .io import load_transit try: t = load_transit(dir, file_format=file_format) @@ -167,8 +164,8 @@ def validate_transit_in_dir( WranglerLogger.error(f"!!! [Transit Network invalid] - Failed Loading to Feed object\n{e}") return False if road_dir is not None: - from ..roadway import load_roadway_from_dir # noqa: PLC0415 - from .network import TransitRoadwayConsistencyError # noqa: PLC0415 + from ..roadway import load_roadway_from_dir + from .network import TransitRoadwayConsistencyError try: r = load_roadway_from_dir(road_dir, file_format=road_file_format) diff --git a/network_wrangler/utils/data.py b/network_wrangler/utils/data.py index 5e7bc154..574bdf8e 100644 --- a/network_wrangler/utils/data.py +++ b/network_wrangler/utils/data.py @@ -3,11 +3,10 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any, Optional, Union +from typing import Any import numpy as np import pandas as pd -import pyarrow as pa from geopandas import GeoDataFrame, GeoSeries from numpy import ndarray from shapely import wkt @@ -71,7 +70,7 @@ def update_df_by_col_value( destination_df: pd.DataFrame, source_df: pd.DataFrame, join_col: str, - properties: Optional[list[str]] = None, + properties: list[str] | None = None, fail_if_missing: bool = True, ) -> pd.DataFrame: """Updates destination_df with ALL values in source_df for specified props with same join_col. @@ -214,7 +213,7 @@ def _update_props_for_common_idx( return updated_df -def list_like_columns(df, item_type: Optional[type] = None) -> list[str]: +def list_like_columns(df, item_type: type | None = None) -> list[str]: """Find columns in a dataframe that contain list-like items that can't be json-serialized. 
Args: @@ -225,7 +224,7 @@ def list_like_columns(df, item_type: Optional[type] = None) -> list[str]: list_like_columns = [] for column in df.columns: - if df[column].apply(lambda x: isinstance(x, (list, ndarray))).any(): + if df[column].apply(lambda x: isinstance(x, list | ndarray)).any(): if item_type is not None and not isinstance(df[column].iloc[0], item_type): continue list_like_columns.append(column) @@ -233,7 +232,7 @@ def list_like_columns(df, item_type: Optional[type] = None) -> list[str]: def compare_df_values( - df1, df2, join_col: Optional[str] = None, ignore: Optional[list[str]] = None, atol=1e-5 + df1, df2, join_col: str | None = None, ignore: list[str] | None = None, atol=1e-5 ): """Compare overlapping part of dataframes and returns where there are differences.""" if ignore is None: @@ -286,7 +285,7 @@ def compare_df_values( return comp_df -def diff_dfs(df1, df2, ignore: Optional[list[str]] = None) -> bool: +def diff_dfs(df1, df2, ignore: list[str] | None = None) -> bool: """Returns True if two dataframes are different and log differences.""" if ignore is None: ignore = [] @@ -351,13 +350,13 @@ def diff_list_like_series(s1, s2) -> bool: def segment_data_by_selection( item_list: list, - data: Union[list, pd.DataFrame, pd.Series], - field: Optional[str] = None, + data: list | pd.DataFrame | pd.Series, + field: str | None = None, end_val=0, ) -> tuple[ - Union[pd.Series, list, pd.DataFrame], - Union[pd.Series, list, pd.DataFrame], - Union[pd.Series, list, pd.DataFrame], + pd.Series | list | pd.DataFrame, + pd.Series | list | pd.DataFrame, + pd.Series | list | pd.DataFrame, ]: """Segment a dataframe or series into before, middle, and end segments based on item_list. @@ -369,7 +368,7 @@ def segment_data_by_selection( Args: item_list (list): List of items to segment data by. If longer than two, will only use the first and last items. - data (Union[pd.Series, pd.DataFrame]): Data to segment into before, middle, and after. 
+ data (pd.Series | pd.DataFrame): Data to segment into before, middle, and after. field (str, optional): If a dataframe, specifies which field to reference. Defaults to None. end_val (int, optional): Notation for util the end or from the begining. Defaults to 0. @@ -429,7 +428,7 @@ def segment_data_by_selection( selected_segment = data[start_idx:end_idx] after_segment = data[end_idx:] - if isinstance(data, (pd.DataFrame, pd.Series)): + if isinstance(data, pd.DataFrame | pd.Series): before_segment = before_segment.reset_index(drop=True) selected_segment = selected_segment.reset_index(drop=True) after_segment = after_segment.reset_index(drop=True) @@ -449,9 +448,9 @@ def segment_data_by_selection_min_overlap( ) -> tuple[ list, tuple[ - Union[pd.Series, pd.DataFrame], - Union[pd.Series, pd.DataFrame], - Union[pd.Series, pd.DataFrame], + pd.Series | pd.DataFrame, + pd.Series | pd.DataFrame, + pd.Series | pd.DataFrame, ], ]: """Segments data based on item_list reducing overlap with replacement list. @@ -474,7 +473,7 @@ def segment_data_by_selection_min_overlap( Args: selection_list (list): List of items to segment data by. If longer than two, will only use the first and last items. - data (Union[pd.Series, pd.DataFrame]): Data to segment into before, middle, and after. + data (pd.Series | pd.DataFrame): Data to segment into before, middle, and after. field (str): Specifies which field to reference. replacements_list (list): List of items to eventually replace the selected segment with. end_val (int, optional): Notation for util the end or from the begining. Defaults to 0. 
@@ -551,7 +550,7 @@ def validate_existing_value_in_df(df: pd.DataFrame, idx: list[int], field: str, return True -CoerceTypes = Union[str, int, float, bool, list[Union[str, int, float, bool]]] +CoerceTypes = str | int | float | bool | list[str | int | float | bool] def coerce_val_to_df_types( # noqa: PLR0911 @@ -591,7 +590,7 @@ def coerce_val_to_df_types( # noqa: PLR0911 def coerce_dict_to_df_types( d: dict[str, CoerceTypes], df: pd.DataFrame, - skip_keys: Optional[list] = None, + skip_keys: list | None = None, return_skipped: bool = False, ) -> dict[str, CoerceTypes]: """Coerce dictionary values to match the type of a dataframe columns matching dict keys. @@ -636,7 +635,7 @@ def coerce_dict_to_df_types( return coerced_dict -def coerce_val_to_series_type(val, s: pd.Series) -> Union[float, str, bool]: +def coerce_val_to_series_type(val, s: pd.Series) -> float | str | bool: """Coerces a value to match type of pandas series. Will try not to fail so if you give it a value that can't convert to a number, it will @@ -650,7 +649,7 @@ def coerce_val_to_series_type(val, s: pd.Series) -> Union[float, str, bool]: # {pd.api.types.infer_dtype(s)}.") if pd.api.types.infer_dtype(s) in ["integer", "floating"]: try: - v: Union[float, str, bool] = float(val) + v: float | str | bool = float(val) except: v = str(val) elif pd.api.types.infer_dtype(s) == "boolean": @@ -662,7 +661,7 @@ def coerce_val_to_series_type(val, s: pd.Series) -> Union[float, str, bool]: def fk_in_pk( - pk: Union[pd.Series, list], fk: Union[pd.Series, list], ignore_nan: bool = True + pk: pd.Series | list, fk: pd.Series | list, ignore_nan: bool = True ) -> tuple[bool, list]: """Check if all foreign keys are in the primary keys, optionally ignoring NaN.""" if isinstance(fk, list): @@ -695,7 +694,7 @@ def dict_fields_in_df(d: dict, df: pd.DataFrame) -> bool: def concat_with_attr(dfs: list[pd.DataFrame], **kwargs) -> pd.DataFrame: """Concatenate a list of dataframes and retain the attributes of the first dataframe.""" 
- import copy # noqa: PLC0415 + import copy if not dfs: msg = "No dataframes to concatenate." diff --git a/network_wrangler/utils/geo.py b/network_wrangler/utils/geo.py index 3dc39ac3..ca52d876 100644 --- a/network_wrangler/utils/geo.py +++ b/network_wrangler/utils/geo.py @@ -5,11 +5,10 @@ import copy import math from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING import geopandas as gpd import pandas as pd -import pyarrow as pa from geographiclib.geodesic import Geodesic from pyproj import CRS, Proj, Transformer from shapely.geometry import LineString, Point @@ -84,7 +83,7 @@ def offset_point_with_distance_and_bearing( return [out_lon, out_lat] -def length_of_linestring_miles(gdf: Union[gpd.GeoSeries, gpd.GeoDataFrame]) -> pd.Series: +def length_of_linestring_miles(gdf: gpd.GeoSeries | gpd.GeoDataFrame) -> pd.Series: """Returns a Series with the linestring length in miles. Args: @@ -200,7 +199,9 @@ def linestring_from_lats_lons(df, lat_fields, lon_fields) -> gpd.GeoSeries: line_geometries = gpd.GeoSeries( [ - LineString([(row[lon], row[lat]) for lon, lat in zip(lon_fields, lat_fields)]) + LineString( + [(row[lon], row[lat]) for lon, lat in zip(lon_fields, lat_fields, strict=True)] + ) for _, row in df.iterrows() ] ) @@ -331,8 +332,8 @@ def get_point_geometry_from_linestring(polyline_geometry, pos: int = 0): def location_ref_from_point( geometry: Point, sequence: int = 1, - bearing: Optional[float] = None, - distance_to_next_ref: Optional[float] = None, + bearing: float | None = None, + distance_to_next_ref: float | None = None, ) -> LocationReference: """Generates a shared street point location reference. 
@@ -378,9 +379,9 @@ def location_refs_from_linestring(geometry: LineString) -> list[LocationReferenc def get_bounding_polygon( - boundary_geocode: Optional[Union[str, dict]] = None, - boundary_file: Optional[Union[str, Path]] = None, - boundary_gdf: Optional[gpd.GeoDataFrame] = None, + boundary_geocode: str | dict | None = None, + boundary_file: str | Path | None = None, + boundary_gdf: gpd.GeoDataFrame | None = None, crs: int = LAT_LON_CRS, # WGS84 ) -> gpd.GeoSeries: """Get the bounding polygon for a given boundary. @@ -392,9 +393,9 @@ def get_bounding_polygon( geometry is returned as a GeoSeries. Args: - boundary_geocode (Union[str, dict], optional): A geocode string or dictionary + boundary_geocode (str | dict, optional): A geocode string or dictionary representing the boundary. Defaults to None. - boundary_file (Union[str, Path], optional): A path to the boundary file. Only used if + boundary_file (str | Path, optional): A path to the boundary file. Only used if boundary_geocode is None. Defaults to None. boundary_gdf (gpd.GeoDataFrame, optional): A GeoDataFrame representing the boundary. Only used if boundary_geocode and boundary_file are None. Defaults to None. @@ -403,7 +404,7 @@ def get_bounding_polygon( Returns: gpd.GeoSeries: The polygon geometry representing the bounding polygon. """ - import osmnx as ox # noqa: PLC0415 + import osmnx as ox nargs = sum(x is not None for x in [boundary_gdf, boundary_geocode, boundary_file]) if nargs == 0: @@ -453,7 +454,7 @@ def _harmonize_crs(df: pd.DataFrame, crs: int = LAT_LON_CRS) -> pd.DataFrame: return df -def _id_utm_crs(gdf: Union[gpd.GeoSeries, gpd.GeoDataFrame]) -> int: +def _id_utm_crs(gdf: gpd.GeoSeries | gpd.GeoDataFrame) -> int: """Returns the UTM CRS ESPG for the given GeoDataFrame. 
Args: @@ -484,8 +485,8 @@ def offset_geometry_meters(geo_s: gpd.GeoSeries, offset_distance_meters: float) def to_points_gdf( table: pd.DataFrame, - ref_nodes_df: Optional[gpd.GeoDataFrame] = None, - ref_road_net: Optional[RoadwayNetwork] = None, + ref_nodes_df: gpd.GeoDataFrame | None = None, + ref_road_net: RoadwayNetwork | None = None, ) -> gpd.GeoDataFrame: """Convert a table to a GeoDataFrame. diff --git a/network_wrangler/utils/io_dict.py b/network_wrangler/utils/io_dict.py index 93cfdd7e..d0ea82d6 100644 --- a/network_wrangler/utils/io_dict.py +++ b/network_wrangler/utils/io_dict.py @@ -2,11 +2,7 @@ import json from pathlib import Path -from typing import Union -import ijson -import pyarrow as pa -import pyarrow.parquet as pq import toml import yaml @@ -51,7 +47,7 @@ def load_dict(path: Path) -> dict: raise NotImplementedError(msg) -def load_merge_dict(path: Union[Path, list[Path]]) -> dict: +def load_merge_dict(path: Path | list[Path]) -> dict: """Load and merge multiple dictionaries from files.""" if not isinstance(path, list): path = [path] diff --git a/network_wrangler/utils/io_table.py b/network_wrangler/utils/io_table.py index 2cf36c93..0668ea89 100644 --- a/network_wrangler/utils/io_table.py +++ b/network_wrangler/utils/io_table.py @@ -6,7 +6,6 @@ import weakref from datetime import datetime from pathlib import Path -from typing import Optional, Union import geopandas as gpd import pandas as pd @@ -43,7 +42,7 @@ class FileWriteError(Exception): def write_table( - df: Union[pd.DataFrame, gpd.GeoDataFrame], + df: pd.DataFrame | gpd.GeoDataFrame, filename: Path, overwrite: bool = False, **kwargs, @@ -90,7 +89,7 @@ def write_table( def _estimate_read_time_of_file( - filepath: Union[str, Path], read_speed: dict = DefaultConfig.CPU.EST_PD_READ_SPEED + filepath: str | Path, read_speed: dict = DefaultConfig.CPU.EST_PD_READ_SPEED ) -> str: """Estimates read time in seconds based on a given file size and speed factor. 
@@ -109,12 +108,12 @@ def _estimate_read_time_of_file( def read_table( filename: Path, - sub_filename: Optional[str] = None, - boundary_gdf: Optional[gpd.GeoDataFrame] = None, - boundary_geocode: Optional[str] = None, - boundary_file: Optional[Path] = None, + sub_filename: str | None = None, + boundary_gdf: gpd.GeoDataFrame | None = None, + boundary_geocode: str | None = None, + boundary_file: Path | None = None, read_speed: dict = DefaultConfig.CPU.EST_PD_READ_SPEED, -) -> Union[pd.DataFrame, gpd.GeoDataFrame]: +) -> pd.DataFrame | gpd.GeoDataFrame: """Read file and return a dataframe or geodataframe. If filename is a zip file, will unzip to a temporary directory. @@ -184,7 +183,7 @@ def read_table( raise NotImplementedError(msg) -def _read_parquet_table(filename, mask_gdf) -> Union[gpd.GeoDataFrame, pd.DataFrame]: +def _read_parquet_table(filename, mask_gdf) -> gpd.GeoDataFrame | pd.DataFrame: """Read a parquet file and filter to a bounding box if provided. Converts numpy arrays to lists. @@ -220,11 +219,11 @@ def convert_file_serialization( input_file: Path, output_file: Path, overwrite: bool = True, - boundary_gdf: Optional[gpd.GeoDataFrame] = None, - boundary_geocode: Optional[str] = None, - boundary_file: Optional[Path] = None, - node_filter_s: Optional[pd.Series] = None, - chunk_size: Optional[int] = None, + boundary_gdf: gpd.GeoDataFrame | None = None, + boundary_geocode: str | None = None, + boundary_file: Path | None = None, + node_filter_s: pd.Series | None = None, + chunk_size: int | None = None, ): """Convert a file serialization format to another and optionally filter to a boundary. 
@@ -298,7 +297,7 @@ def _estimate_bytes_per_json_object(json_path: Path) -> float: return total_size / len(json_objects) -def _suggest_json_chunk_size(json_path: Path, memory_fraction: float = 0.6) -> Union[None, int]: +def _suggest_json_chunk_size(json_path: Path, memory_fraction: float = 0.6) -> int | None: """Ascertain if a file should be processed in chunks and how large the chunks should be in mb. Args: @@ -321,8 +320,8 @@ def _suggest_json_chunk_size(json_path: Path, memory_fraction: float = 0.6) -> U def _append_parquet_table( new_data: pd.DataFrame, file_counter=1, - base_filename: Optional[str] = None, - directory: Optional[Path] = None, + base_filename: str | None = None, + directory: Path | None = None, ) -> Path: """Append new data to a Parquet dataset directory. @@ -338,8 +337,8 @@ def _append_parquet_table( Returns: Path: The path to the output directory. """ - import pyarrow as pa # noqa: PLC0415 - import pyarrow.parquet as pq # noqa: PLC0415 + import pyarrow as pa + import pyarrow.parquet as pq if directory is None: temp_dir = tempfile.mkdtemp() @@ -366,12 +365,12 @@ def _json_to_parquet_in_chunks(input_file: Path, output_file: Path, chunk_size: chunk_size: Number of JSON objects to process in each chunk. """ try: - import ijson # noqa: PLC0415 + import ijson except ModuleNotFoundError as err: msg = "ijson is required for chunked JSON processing." 
raise ModuleNotFoundError(msg) from err - import pyarrow.parquet as pq # noqa: PLC0415 + import pyarrow.parquet as pq base_filename = Path(output_file).stem directory = None diff --git a/network_wrangler/utils/models.py b/network_wrangler/utils/models.py index 84e0cb4d..7f1d3bf5 100644 --- a/network_wrangler/utils/models.py +++ b/network_wrangler/utils/models.py @@ -3,7 +3,8 @@ import copy from functools import wraps from pathlib import Path -from typing import Optional, Union, _GenericAlias, get_args, get_origin, get_type_hints +from types import UnionType +from typing import Union, _GenericAlias, get_args, get_origin, get_type_hints import geopandas as gpd import pandas as pd @@ -20,6 +21,21 @@ from .data import coerce_val_to_df_types +# Convert StringDtype columns to object dtype to avoid numpy.issubdtype compatibility issues +# in pandas 2.2+ with Python 3.11+ +def _convert_string_dtype_to_object(df: DataFrame) -> DataFrame: + """Convert StringDtype columns to object dtype for compatibility with numpy.issubdtype. + + This fixes compatibility issues with pandas 2.2+ StringDtype and numpy.issubdtype + in Python 3.11+ when used with pandera validation. + """ + df = df.copy() + for col in df.columns: + if isinstance(df[col].dtype, pd.StringDtype): + df[col] = df[col].astype(object) + return df + + class DatamodelDataframeIncompatableError(Exception): """Raised when a data model and a dataframe are not compatable.""" @@ -30,7 +46,7 @@ class TableValidationError(Exception): def empty_df_from_datamodel( model: DataFrameModel, crs: int = LAT_LON_CRS -) -> Union[gpd.GeoDataFrame, pd.DataFrame]: +) -> gpd.GeoDataFrame | pd.DataFrame: """Create an empty DataFrame or GeoDataFrame with the specified columns. Args: @@ -88,6 +104,9 @@ def validate_df_to_model( attrs = copy.deepcopy(df.attrs) err_msg = f"Validation to {model.__name__} failed." 
try: + # Convert StringDtype columns to object dtype before validation to avoid + # numpy.issubdtype compatibility issues with pandas 2.2+ and Python 3.11+ + df = _convert_string_dtype_to_object(df) model_df = model.validate(df, lazy=True) model_df = fill_df_with_defaults_from_model(model_df, model) model_df.attrs = attrs @@ -117,13 +136,11 @@ def validate_df_to_model( raise TableValidationError(err_msg) from e -def identify_model( - data: Union[pd.DataFrame, dict], models: list -) -> Union[DataFrameModel, BaseModel]: +def identify_model(data: pd.DataFrame | dict, models: list) -> DataFrameModel | BaseModel: """Identify the model that the input data conforms to. Args: - data (Union[pd.DataFrame, dict]): The input data to identify. + data (pd.DataFrame | dict): The input data to identify. models (list[DataFrameModel,BaseModel]): A list of models to validate the input data against. """ @@ -156,7 +173,7 @@ def extra_attributes_undefined_in_model(instance: BaseModel, model: BaseModel) - return extra_attributes -def submodel_fields_in_model(model: type, instance: Optional[BaseModel] = None) -> list: +def submodel_fields_in_model(model: type, instance: BaseModel | None = None) -> list: """Find the fields in a pydantic model that are submodels.""" types = get_type_hints(model) model_type = (ModelMetaclass, BaseModel) @@ -210,7 +227,8 @@ def check_type_hint(value): pass return False - if get_origin(type_hint_value) is Union: + origin = get_origin(type_hint_value) + if origin is Union or origin is UnionType: args = get_args(type_hint_value) for arg in args: if check_type_hint(arg): diff --git a/network_wrangler/utils/time.py b/network_wrangler/utils/time.py index 76cec654..84f510d3 100644 --- a/network_wrangler/utils/time.py +++ b/network_wrangler/utils/time.py @@ -15,7 +15,6 @@ from __future__ import annotations from datetime import date, datetime, timedelta -from typing import Optional, Union import pandas as pd from pydantic import validate_call @@ -30,7 +29,7 @@ class 
TimespanDfQueryError(Exception): @validate_call(config={"arbitrary_types_allowed": True}) -def str_to_time(time_str: TimeString, base_date: Optional[date] = None) -> datetime: +def str_to_time(time_str: TimeString, base_date: date | None = None) -> datetime: """Convert TimeString (HH:MM<:SS>) to datetime object. If HH > 24, will subtract 24 to be within 24 hours. Timespans will be treated as the next day. @@ -65,7 +64,7 @@ def str_to_time(time_str: TimeString, base_date: Optional[date] = None) -> datet def _all_str_to_time_series( - time_str_s: pd.Series, base_date: Optional[Union[pd.Series, date]] = None + time_str_s: pd.Series, base_date: pd.Series | date | None = None ) -> pd.Series: """Assume all are strings and convert to datetime objects.""" # check strings are in the correct format, leave existing date times alone @@ -102,7 +101,7 @@ def _all_str_to_time_series( def str_to_time_series( - time_str_s: pd.Series, base_date: Optional[Union[pd.Series, date]] = None + time_str_s: pd.Series, base_date: pd.Series | date | None = None ) -> pd.Series: """Convert mixed panda series datetime and TimeString (HH:MM<:SS>) to datetime object. @@ -218,7 +217,7 @@ def filter_df_to_max_overlapping_timespans( query_timespan: list[TimeString], strict_match: bool = False, min_overlap_minutes: int = 1, - keep_max_of_cols: Optional[list[str]] = None, + keep_max_of_cols: list[str] | None = None, ) -> pd.DataFrame: """Filters dataframe for entries that have maximum overlap with the given query timespan. 
diff --git a/network_wrangler/utils/utils.py b/network_wrangler/utils/utils.py index d31a4aaa..f87f7c3e 100644 --- a/network_wrangler/utils/utils.py +++ b/network_wrangler/utils/utils.py @@ -2,7 +2,6 @@ import hashlib import re -from typing import Union from pydantic import validate_call @@ -51,7 +50,7 @@ def _topology_sort_util(vertex): def make_slug(text: str, delimiter: str = "_") -> str: """Makes a slug from text.""" text = re.sub("[,.;@#?!&$']+", "", text.lower()) - return re.sub("[\ ]+", delimiter, text) + return re.sub(r"[ ]+", delimiter, text) def delete_keys_from_dict(dictionary: dict, keys: list) -> dict: @@ -79,14 +78,14 @@ def delete_keys_from_dict(dictionary: dict, keys: list) -> dict: return modified_dict -def get_overlapping_range(ranges: list[Union[tuple[int, int], range]]) -> Union[None, range]: +def get_overlapping_range(ranges: list[tuple[int, int] | range]) -> range | None: """Returns the overlapping range for a list of ranges or tuples defining ranges. Args: - ranges (list[Union[tuple[int], range]]): A list of ranges or tuples defining ranges. + ranges (list[tuple[int, int] | range]): A list of ranges or tuples defining ranges. Returns: - Union[None, range]: The overlapping range if found, otherwise None. + range | None: The overlapping range if found, otherwise None. 
Example: >>> ranges = [(1, 5), (3, 7), (6, 10)] @@ -200,7 +199,7 @@ def combine_unique_unhashable_list(list1: list, list2: list): return [item for item in list1 if item not in list2] + list2 -def normalize_to_lists(mixed_list: list[Union[str, list]]) -> list[list]: +def normalize_to_lists(mixed_list: list[str | list]) -> list[list]: """Turn a mixed list of scalars and lists into a list of lists.""" normalized_list = [] for item in mixed_list: @@ -212,7 +211,7 @@ def normalize_to_lists(mixed_list: list[Union[str, list]]) -> list[list]: @validate_call -def list_elements_subset_of_single_element(mixed_list: list[Union[str, list[str]]]) -> bool: +def list_elements_subset_of_single_element(mixed_list: list[str | list[str]]) -> bool: """Find the first list in the mixed_list.""" potential_supersets = [] for item in mixed_list: @@ -234,7 +233,7 @@ def list_elements_subset_of_single_element(mixed_list: list[Union[str, list[str] def check_one_or_one_superset_present( - mixed_list: list[Union[str, list[str]]], all_fields_present: list[str] + mixed_list: list[str | list[str]], all_fields_present: list[str] ) -> bool: """Checks that exactly one of the fields in mixed_list is in fields_present or one superset.""" normalized_list = normalize_to_lists(mixed_list) diff --git a/network_wrangler/viz.py b/network_wrangler/viz.py index 99b37067..cbe745c9 100644 --- a/network_wrangler/viz.py +++ b/network_wrangler/viz.py @@ -9,10 +9,8 @@ import os import subprocess from pathlib import Path -from typing import Optional, Union import geopandas as gpd -import pyarrow as pa from .logger import WranglerLogger from .roadway.network import RoadwayNetwork @@ -24,8 +22,8 @@ class MissingMapboxTokenError(Exception): def net_to_mapbox( - roadway: Optional[Union[RoadwayNetwork, gpd.GeoDataFrame, str, Path]] = None, - transit: Optional[Union[TransitNetwork, gpd.GeoDataFrame]] = None, + roadway: RoadwayNetwork | gpd.GeoDataFrame | str | Path | None = None, + transit: TransitNetwork | gpd.GeoDataFrame 
| None = None, roadway_geojson_out: Path = Path("roadway_shapes.geojson"), transit_geojson_out: Path = Path("transit_shapes.geojson"), mbtiles_out: Path = Path("network.mbtiles"), diff --git a/pyproject.toml b/pyproject.toml index cb496e33..eb65620d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "" license = {file = "LICENSE"} readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" authors = [ { name = "Elizabeth Sall", email = "elizabeth@urbanlabs.io" }, { name = "Sijia Wang", email = "Sijia.Wang@wsp.com"}, @@ -22,6 +22,10 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dependencies = [ "fiona>=1.10.1", @@ -30,7 +34,7 @@ dependencies = [ "geopandas>=1.0.1", "ijson>=3.3.0", "osmnx>=1.9.3", - "pandas>=2.2.3", + "pandas>=2.2.3,<3.0", "pandera[pandas,geopandas]>=0.24.0", "projectcard>=0.3.3", "psutil>=6.0.0", @@ -50,7 +54,7 @@ viz = [ docs = [ "fontawesome_markdown", "jupyter", - "markdown == 3.3.1", + "markdown>=3.4.0,<4.0", # 3.4.0+ supports new EntryPoints API "mike", "mkdocs", "mkdocs-autorefs", @@ -113,7 +117,7 @@ source = ["network_wrangler"] [tool.mypy] files = "network_wrangler" -python_version = "3.9" +python_version = "3.10" follow_imports = "skip" show_error_codes = true show_error_context = true @@ -130,12 +134,9 @@ exclude = ["notebook/*.ipynb"] select = ["D"] ignore = [ "RET504", # Unnecessary assignment before `return` statement - "UP007", # non pep-604 annotations. + "PLC0415", # `import` should be at the top-level (intentional lazy imports to avoid circular deps) "C416", # non pep-604 annotations. - "UP007", # non pep-604 annotations. 
"PLR0913", # too many args - "UP045", # Use `X | None` for type annotations - "UP006", # Use `list` instead of `List` for type annotation ] extend-select = [ "B", # flake8-bugbear diff --git a/requirements.docs.txt b/requirements.docs.txt index a3593734..5e9232f5 100644 --- a/requirements.docs.txt +++ b/requirements.docs.txt @@ -1,6 +1,6 @@ fontawesome_markdown jupyter -markdown == 3.3.1 # needs to be compatible with mkdocs, which needs > markdown 3.2.1 +markdown>=3.4.0,<4.0 # 3.4.0+ supports new EntryPoints API, compatible with mkdocs (>3.2.1) mike mkdocs mkdocs-autorefs diff --git a/roadway/clip.py b/roadway/clip.py index c594de0d..f5bcf04c 100644 --- a/roadway/clip.py +++ b/roadway/clip.py @@ -1 +1,3 @@ +"""Roadway network clipping utilities.""" + from .network import RoadwayNetwork diff --git a/roadway/network.py b/roadway/network.py index 8ae265b8..b2c8eef8 100644 --- a/roadway/network.py +++ b/roadway/network.py @@ -1,3 +1,5 @@ +"""Roadway network module exports.""" + from .graph import net_to_graph from .links.scopes import prop_for_scope from .nodes.nodes import node_ids_without_links diff --git a/roadway/nodes/validate.py b/roadway/nodes/validate.py index d775ea99..afcbc770 100644 --- a/roadway/nodes/validate.py +++ b/roadway/nodes/validate.py @@ -1 +1,3 @@ +"""Roadway node validation utilities.""" + from .create import data_to_nodes_df diff --git a/roadway/segment.py b/roadway/segment.py index be63c30a..1e5f52be 100644 --- a/roadway/segment.py +++ b/roadway/segment.py @@ -1 +1,3 @@ +"""Roadway segment utilities.""" + from .network import add_incident_link_data_to_nodes diff --git a/roadway/shapes/validate.py b/roadway/shapes/validate.py index ad2f0b8e..2d893b47 100644 --- a/roadway/shapes/validate.py +++ b/roadway/shapes/validate.py @@ -1 +1,3 @@ +"""Roadway shape validation utilities.""" + from .create import df_to_shapes_df diff --git a/tests/test_docs.py b/tests/test_docs.py deleted file mode 100644 index b1fbebb0..00000000 --- a/tests/test_docs.py +++ 
/dev/null @@ -1,12 +0,0 @@ -"""Tests the documentation can be built without errors.""" - -import subprocess - -from network_wrangler.logger import WranglerLogger - - -def test_mkdocs_build(request): - """Tests that the MkDocs documentation can be built without errors.""" - WranglerLogger.info(f"--Starting: {request.node.name}") - subprocess.run(["mkdocs", "build"], capture_output=True, text=True, check=True) - WranglerLogger.info(f"--Finished: {request.node.name}") diff --git a/tests/test_roadway/test_changes/test_roadway_add_delete.py b/tests/test_roadway/test_changes/test_roadway_add_delete.py index 07e581a9..63ef4651 100644 --- a/tests/test_roadway/test_changes/test_roadway_add_delete.py +++ b/tests/test_roadway/test_changes/test_roadway_add_delete.py @@ -61,7 +61,7 @@ def test_add_roadway_link_project_card(request, small_net): WranglerLogger.debug(f"New Links: \n{_new_links}") assert len(_new_links) == len(_links) assert _new_links.at[_new_link_idxs[0], "projects"] == f"{_project}," - assert set(zip(_new_links.A, _new_links.B)) == set(_expected_new_link_fks) + assert set(zip(_new_links.A, _new_links.B, strict=False)) == set(_expected_new_link_fks) WranglerLogger.info(f"--Finished: {request.node.name}") diff --git a/tests/test_roadway/test_selections.py b/tests/test_roadway/test_selections.py index e2f8a223..df34b98a 100644 --- a/tests/test_roadway/test_selections.py +++ b/tests/test_roadway/test_selections.py @@ -184,7 +184,9 @@ def test_dfhash(request, stpaul_net): ] -@pytest.mark.parametrize(("selection", "answer"), zip(TEST_SELECTIONS, answer_selected_links)) +@pytest.mark.parametrize( + ("selection", "answer"), zip(TEST_SELECTIONS, answer_selected_links, strict=False) +) def test_select_roadway_features(request, selection, answer, stpaul_net): WranglerLogger.info(f"--Starting: {request.node.name}") net = stpaul_net diff --git a/tests/test_utils/test_data.py b/tests/test_utils/test_data.py index 5f43b874..982dd82a 100644 --- a/tests/test_utils/test_data.py 
+++ b/tests/test_utils/test_data.py @@ -473,7 +473,7 @@ def test_segment_series_by_list(request): ) calc_answer = segment_data_by_selection(item_list, s) - for calc, exp in zip(calc_answer, exp_answer): + for calc, exp in zip(calc_answer, exp_answer, strict=False): WranglerLogger.debug(f"\ncalc: \n{calc}") WranglerLogger.debug(f"\nexp: \n{exp}") tm.assert_series_equal(calc, exp) @@ -488,7 +488,7 @@ def test_segment_df_by_list(request): exp_answer = ([1], [2, 3, 4, 3], [2, 5]) calc_answer = segment_data_by_selection(item_list, s, field="mynodes") - for calc, exp in zip(calc_answer, exp_answer): + for calc, exp in zip(calc_answer, exp_answer, strict=False): # WranglerLogger.debug(f"\ncalc:\n{calc['mynodes']}") # WranglerLogger.debug(f"\nexp:\n{exp}") assert exp == calc["mynodes"].to_list() diff --git a/transit/io.py b/transit/io.py index c81b6de1..9f8fd5e2 100644 --- a/transit/io.py +++ b/transit/io.py @@ -1,2 +1,3 @@ -# Input and output functions for transit data +"""Input and output functions for transit data.""" + from .geo import shapes_to_shape_links_gdf