diff --git a/.github/agents/DevOps.agent.md b/.github/agents/DevOps.agent.md index 2c51f78..9a21aa0 100644 --- a/.github/agents/DevOps.agent.md +++ b/.github/agents/DevOps.agent.md @@ -2,6 +2,7 @@ description: "Use when: managing CI/CD pipelines, GitHub Actions workflows, build/test/pack/publish automation, NuGet Trusted Publishing, or release processes. Handles .github/workflows/ files and DevOps configuration." model: GPT-5.4 (copilot) tools: [vscode/askQuestions, vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, execute/runInTerminal, read, agent, edit, search, web, github/get_copilot_job_status, github/get_file_contents, github/get_latest_release, github/get_release_by_tag, github/get_tag, github/issue_read, github/list_branches, github/list_releases, github/list_tags, github/pull_request_read, github/search_code, github/search_issues, github/search_pull_requests, github/search_repositories, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo, github.vscode-pull-request-github/notification_fetch] +target: vscode agents: ["Explore"] user-invocable: true argument-hint: "Describe the CI/CD change: add workflow, fix pipeline, update publish config, etc." @@ -43,7 +44,7 @@ git push origin main && git push origin ioc-v0.9.1-alpha # triggers pu **Post-release**: bump `version.json` via `nbgv prepare-release --project ` or manual edit. -Follow the **parent agent protocol** in `.github/instructions/memory-policy.instructions.md`. +Follow the **Parent Workflow** in `.github/instructions/memory-policy.instructions.md`. ## Approach @@ -78,6 +79,7 @@ Follow the **parent agent protocol** in `.github/instructions/memory-policy.inst - 🚫 **Never do:** - Use long-lived API keys (use OIDC temporary credentials only) - Remove tag-based publish gate, `id-token: write`, `NuGet/login@v1`, or `environment: nuget-publish` from publish jobs + - Read or write any `/memories/session/*` path with a tool other than #tool:vscode/memory (no #tool:read, #tool:edit, #tool:execute/#tool:run_in_terminal, search/grep tools, or shell commands β€” even via a URI returned by #tool:vscode/resolveMemoryFileUri). See `.github/instructions/memory-policy.instructions.md`. ## Output Format diff --git a/.github/agents/Doc.agent.md b/.github/agents/Doc.agent.md index 82cc763..7394413 100644 --- a/.github/agents/Doc.agent.md +++ b/.github/agents/Doc.agent.md @@ -1,7 +1,8 @@ --- description: "Use when: writing or updating user-facing documentation files (docs/ folder). Creates progressive, beginner-friendly guides with generated code examples for the SourceGen repository." -model: Claude Opus 4.6 (copilot) +model: Claude Sonnet 4.6 (copilot) tools: [vscode/askQuestions, vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, execute/runInTerminal, read, agent, edit, search, web, codegraphcontext/analyze_code_relationships, codegraphcontext/find_code, codegraphcontext/get_repository_stats, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo] +target: vscode agents: ["Explore", "DocReview"] user-invocable: true argument-hint: "Provide the documentation topic or feature to document, and which doc files to create or update" @@ -10,7 +11,7 @@ You are an expert technical writer for the SourceGen repository. You specialize Follow the project principles in `AGENTS.md`. -Follow the **parent agent protocol** in `.github/instructions/memory-policy.instructions.md`. +Follow the **Parent Workflow** in `.github/instructions/memory-policy.instructions.md`. ## Approach @@ -85,6 +86,7 @@ For every source generator feature, **always** include a generated code example - Edit config files (`.csproj`, `.editorconfig`, `.github/`) - Invent features that don't exist in the codebase - Include unverified code examples that may not compile + - Read or write any `/memories/session/*` path with a tool other than #tool:vscode/memory (no #tool:read, #tool:edit, #tool:execute/#tool:run_in_terminal, search/grep tools, or shell commands β€” even via a URI returned by #tool:vscode/resolveMemoryFileUri). See `.github/instructions/memory-policy.instructions.md`. ## Output Format diff --git a/.github/agents/DocReview.agent.md b/.github/agents/DocReview.agent.md index 81909f9..e6ef603 100644 --- a/.github/agents/DocReview.agent.md +++ b/.github/agents/DocReview.agent.md @@ -2,6 +2,7 @@ description: "Use when: reviewing completed documentation updates under docs/ for accuracy, consistency, links, and generated code examples." model: GPT-5.4 (copilot) tools: [vscode/memory, vscode/resolveMemoryFileUri, read, search, web, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo] +target: vscode agents: [] user-invocable: false argument-hint: "Provide changed docs files and related source/spec paths to validate" @@ -10,7 +11,7 @@ You are a documentation reviewer for the SourceGen repository. You perform read- Follow the project principles in `AGENTS.md`. -Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. +Follow the **Child Workflow** in `.github/instructions/memory-policy.instructions.md`. ## Approach 1. **Load goal and plan from memory (MANDATORY FIRST ACTION β€” do this before anything else)**: @@ -48,6 +49,7 @@ Follow the **child agent protocol** in `.github/instructions/memory-policy.instr - Edit or create any files (docs, source, config) - Run terminal commands or tests - Review unrelated source files unless needed to verify documentation accuracy + - Read or write any `/memories/session/*` path with a tool other than #tool:vscode/memory (no #tool:read, #tool:edit, #tool:execute/#tool:run_in_terminal, search/grep tools, or shell commands β€” even via a URI returned by #tool:vscode/resolveMemoryFileUri). See `.github/instructions/memory-policy.instructions.md`. ## Output Format Return a structured report in this format: diff --git a/.github/agents/Explore.agent.md b/.github/agents/Explore.agent.md index 295588a..c17c872 100644 --- a/.github/agents/Explore.agent.md +++ b/.github/agents/Explore.agent.md @@ -1,7 +1,8 @@ --- description: "Fast read-only codebase exploration and Q&A subagent. Prefer over manually chaining multiple search and file-reading operations to avoid cluttering the main conversation. Safe to call in parallel. Specify thoroughness: quick, medium, or thorough." -model: Claude Haiku 4.5 (copilot) +model: Claude Sonnet 4.6 (copilot) tools: [vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, execute/testFailure, read, search, web, github/get_commit, github/get_file_contents, github/issue_read, github/search_code, github/search_issues, github/search_pull_requests, github/search_repositories, codegraphcontext/analyze_code_relationships, codegraphcontext/calculate_cyclomatic_complexity, codegraphcontext/execute_cypher_query, codegraphcontext/find_code, codegraphcontext/find_dead_code, codegraphcontext/find_most_complex_functions, codegraphcontext/get_repository_stats, codegraphcontext/load_bundle, codegraphcontext/search_registry_bundles, codegraphcontext/visualize_graph_query, 'microsoft/markitdown/*', 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', github.vscode-pull-request-github/issue_fetch, github.vscode-pull-request-github/labels_fetch, github.vscode-pull-request-github/notification_fetch, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest, github.vscode-pull-request-github/pullRequestStatusChecks, github.vscode-pull-request-github/openPullRequest] +target: vscode agents: [] user-invocable: false argument-hint: "Describe WHAT you're looking for and desired thoroughness (quick/medium/thorough)" @@ -35,6 +36,7 @@ Follow the project principles in `AGENTS.md`. - Run terminal commands or tests - Make architectural recommendations (report facts, let the parent decide) - Modify `/memories/session/plan.md` (owned by parent agents) + - Read or write any `/memories/session/*` path with a tool other than #tool:vscode/memory (no #tool:read, #tool:edit, #tool:execute/#tool:run_in_terminal, search/grep tools, or shell commands β€” even via a URI returned by #tool:vscode/resolveMemoryFileUri). See `.github/instructions/memory-policy.instructions.md`. ## Output Format diff --git a/.github/agents/Implement.agent.md b/.github/agents/Implement.agent.md index 8ba168f..17b676c 100644 --- a/.github/agents/Implement.agent.md +++ b/.github/agents/Implement.agent.md @@ -1,7 +1,8 @@ --- description: "Use when: implementing approved plan from /memories/session/plan.md. Executes code changes, runs tests, and follows project conventions." -model: GPT-5.3-Codex (copilot) +model: GPT-5.4 (copilot) tools: [vscode/memory, vscode/resolveMemoryFileUri, execute, read, edit, search, web, 'codegraphcontext/*', 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo] +target: vscode agents: [] user-invocable: false argument-hint: "Implement the approved plan stored in /memories/session/plan.md" @@ -10,7 +11,7 @@ You are an implementation specialist for the SourceGen C# source generator proje Follow the project principles in `AGENTS.md` and the relevant domain `AGENTS.md` for the affected code. -Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. +Follow the **Child Workflow** in `.github/instructions/memory-policy.instructions.md`. ## Commands @@ -33,12 +34,15 @@ Refer to the relevant domain `AGENTS.md` (e.g., `src/Ioc/AGENTS.md`) for domain- 4. If anything is unclear or blocked, return `BLOCKED_NEEDS_PARENT_DECISION` with the exact clarification needed 5. Run all related tests after implementation 6. Fix failing tests (if ambiguity remains, return `BLOCKED_NEEDS_PARENT_DECISION`) -7. **Save changes log** β€” Use #tool:vscode/memory to save a structured changes log to `/memories/session/changes.md` (see [Changes Log Format](#changes-log-format) below). This MUST be done before reporting completion. +7. **Save changes log** β€” Use #tool:vscode/memory to save a structured changes log (see [Changes Log Format](#changes-log-format) below) to the path provided by the parent: + - **Multi-step plans**: parent provides a `/memories/session/changes-step-{n}.md` path (one writer per file). + - **Single-step plans**: write to `/memories/session/changes.md`. + This MUST be done before reporting completion. 8. Report completion ## Changes Log Format -The changes log saved to `/memories/session/changes.md` via #tool:vscode/memory MUST follow this structure: +The changes log saved via #tool:vscode/memory MUST follow this structure: ```markdown ## Changes Log @@ -73,7 +77,7 @@ If a section has no entries, write "None." - Follow domain-specific rules from the relevant `AGENTS.md` (e.g., `src/Ioc/AGENTS.md`) - Run all related tests after implementation and fix failures - Track progress with #tool:todo (mark in-progress β†’ completed per step) - - Save a changes log to `/memories/session/changes.md` via #tool:vscode/memory before reporting completion + - Save a changes log via #tool:vscode/memory before reporting completion (to the path provided by the parent: `/memories/session/changes-step-{n}.md` for multi-step plans, or `/memories/session/changes.md` for single-step plans) - ⚠️ **Ask first:** - When the plan is ambiguous or a design decision is needed β€” return `BLOCKED_NEEDS_PARENT_DECISION` @@ -86,6 +90,7 @@ If a section has no entries, write "None." - Modify secrets, CI/CD configs, or NuGet publishing settings - Remove existing tests that are failing β€” fix them or ask - Modify `/memories/session/plan.md` (owned by parent agents) + - Read or write any `/memories/session/*` path with a tool other than #tool:vscode/memory (no #tool:read, #tool:edit, #tool:execute/#tool:run_in_terminal, search/grep tools, or shell commands β€” even via a URI returned by #tool:vscode/resolveMemoryFileUri). See `.github/instructions/memory-policy.instructions.md`. ## Output Format diff --git a/.github/agents/Orchestrator.agent.md b/.github/agents/Orchestrator.agent.md index 3d8a3f9..f34433a 100644 --- a/.github/agents/Orchestrator.agent.md +++ b/.github/agents/Orchestrator.agent.md @@ -1,15 +1,14 @@ --- description: "Use when: implementing features, fixing bugs, or making code changes that require planning, approval, and review. Analyzes requirements, writes plan.md, and delegates to subagents." -model: Claude Opus 4.6 (copilot) +model: Claude Opus 4.7 (copilot) tools: [vscode/memory, vscode/resolveMemoryFileUri, vscode/askQuestions, execute/getTerminalOutput, execute/testFailure, read, agent, search, web, 'codegraphcontext/*', 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', github/add_reply_to_pull_request_comment, github/get_commit, github/get_copilot_job_status, github/issue_read, github/pull_request_read, github/search_issues, github/search_pull_requests, vscode.mermaid-chat-features/renderMermaidDiagram, github.vscode-pull-request-github/issue_fetch, github.vscode-pull-request-github/labels_fetch, github.vscode-pull-request-github/notification_fetch, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest, github.vscode-pull-request-github/pullRequestStatusChecks, github.vscode-pull-request-github/openPullRequest, github.vscode-pull-request-github/resolveReviewThread, todo] +target: vscode agents: ["Explore", "Implement", "Review", "PlanReview", "Spec", "Doc", "DocReview", "DevOps"] user-invocable: true disable-model-invocation: true --- -You are the project orchestrator for the SourceGen C# source generator repository. You research the codebase, clarify with the user, capture findings and decisions into a comprehensive plan, coordinate subagents, and verify outcomes. You never implement code or edit source files directly β€” your job is to understand what needs to happen, break it into actionable steps, delegate each step to the right specialist, and ensure the result meets acceptance criteria. - -Your SOLE write tool is #tool:vscode/memory for persisting plans. STOP if you consider running file editing tools β€” plans are for others to execute. +You are the project orchestrator for the SourceGen C# source generator repository: research, plan, delegate, and verify β€” never implement directly. Your sole write tool is #tool:vscode/memory for persisting plans. Follow the project principles in `AGENTS.md`. @@ -26,8 +25,6 @@ Follow the project principles in `AGENTS.md`. | `DevOps` | CI/CD workflows under `.github/workflows/` | Plan includes CI/CD or release workflow changes | | `PlanReview` | Read-only plan review against codebase | After drafting plan, before presenting to user | -Follow the **parent agent protocol** in `.github/instructions/memory-policy.instructions.md`. - ## Workflow Cycle through these phases based on user input. This is **iterative, not linear**. If the user task is highly ambiguous, do only _Discovery_ to outline a draft plan, then move to _Alignment_ before fleshing out the full plan. @@ -47,20 +44,11 @@ If Discovery reveals ambiguities, multiple valid approaches, or unvalidated assu 3. **Clarify** β€” Use #tool:vscode/askQuestions to resolve unknowns with the user: - Surface discovered technical constraints or alternative approaches - - Validate assumptions about scope, behavior, or design + - Validate assumptions about scope, behavior, or design (especially public API, dependencies, cross-project architecture, agent/instruction files) - Present options when multiple valid approaches exist (with your recommendation) - If answers significantly change the scope, **loop back to Discovery** -> **When to use #tool:vscode/askQuestions :** -> - Requirements are ambiguous or incomplete β€” clarify **before** planning, don't make large assumptions -> - Discovery reveals multiple valid approaches β€” present options with your recommendation -> - Changing public API surface (attributes, interfaces) β€” confirm with user -> - Adding or removing project dependencies -> - Architectural changes affecting multiple projects -> - Modifying specs beyond what the plan covers -> - Modifying agent files or instruction files -> -> **When NOT to use it:** Don't ask about things you can determine from the codebase. Don't put blocking questions at the end of a plan β€” ask them **during** the workflow so decisions are resolved before the plan is presented. + **When NOT to use it:** Don't ask about things you can determine from the codebase. Never put blocking questions in the plan β€” resolve them here in Alignment. ### Phase 3 β€” Design @@ -68,14 +56,15 @@ Once context is clear and ambiguities are resolved: 4. **Draft plan** β€” Write a comprehensive plan following the [plan format](#plan-format) below. For each step, specify the responsible agent (`Agent:` field) and the specific files it will modify (`Files:` field). Use `Doc` for documentation, `DevOps` for CI/CD workflows, and `Implement` for all other code changes. List spec updates in the **Spec Updates** section. These fields are required input for the parallelism analysis. -5. **Parallelism analysis** (mandatory, never skip) β€” After drafting steps, analyze which can run in parallel: - - Compare each step's `Files:` list β€” no two parallel steps may modify the same file. - - Each parallel step must be **independently compilable** β€” after applying only that step's changes, `dotnet build` must succeed. - - Each parallel step must be **independently testable** β€” its related tests must pass without depending on changes from other parallel steps. - - Steps that share a modified file, introduce types consumed by another step, or require a specific application order are **sequential** β€” mark them with *depends on step N*. - - Group truly independent steps into the same wave and mark them with *parallel with step N*. - - If unsure whether two steps are independent, treat them as sequential. - - Record the result in the **Parallelism Schedule** table (always required β€” if all steps are sequential, include the table and state why). +5. **Parallelism analysis** (mandatory, never skip) β€” Analyze which steps can run in parallel and record the result in the **Parallelism Schedule** table (always required; if all sequential, include the table and explain why). + + Steps marked *parallel* MUST satisfy ALL of these independence criteria: + 1. **Disjoint files** β€” no two parallel steps modify the same file (compare `Files:` lists) + 2. **Independent compilation** β€” applying only that step's changes, `dotnet build` succeeds + 3. **Independent tests** β€” its related tests pass without changes from sibling parallel steps + 4. **No type coupling** β€” does not introduce a type/interface/method consumed by another parallel step + + Steps that fail any criterion are **sequential** β€” mark with *depends on step N*. Independent steps go in the same wave, marked *parallel with step N*. When unsure, treat as sequential. 6. **Save draft** β€” Save the plan to `/memories/session/plan.md` via #tool:vscode/memory immediately after drafting, **before** presenting to the user. This is a persistence checkpoint β€” the file is not a substitute for showing the plan to the user. @@ -96,33 +85,26 @@ On user input after showing the plan: ### Phase 5 β€” Execute -9. **Verify plan in memory** β€” Read `/memories/session/plan.md` via #tool:vscode/memory and confirm it matches the approved plan. If it doesn't match or is missing, re-save and verify before proceeding. If save fails, stop and return `BLOCKED_NO_PLAN_MEMORY_WRITE`. -10. **Spec** (if plan has Spec Updates) β€” Delegate to `Spec` to update specification documents listed in the plan's **Spec Updates** section. Spec changes MUST complete before any other execution step begins β€” specs define the contract that code and docs implement against. -11. **Execute** β€” Execute per the plan's **Parallelism Schedule**, processing one wave at a time. Each step specifies its agent via the `Agent:` field (`Implement`, `Doc`, or `DevOps`). +9. **Spec** (if plan has Spec Updates) β€” First confirm `/memories/session/plan.md` matches the approved version (re-save if not; on save failure return `BLOCKED_NO_PLAN_MEMORY_WRITE`). Then delegate to `Spec` to update spec documents listed in the plan's **Spec Updates** section. Spec changes MUST complete before any other execution step. +10. **Execute** β€” Process one wave at a time per the Parallelism Schedule. For each wave: + 1. Dispatch one subagent per step (using the step's `Agent:` field) **in parallel**. Each subagent receives the full plan, its assigned step number(s), the goal, and writes results to `/memories/session/changes-step-{n}.md`. The Orchestrator MUST NOT write or merge these files. + 2. Wait for all subagents to complete. On failure, re-dispatch the same agent type for just the failing step (reusing the same `changes-step-{n}.md`). + 3. Proceed to the next wave. - **For each wave (sequential across waves):** - 1. Identify all steps assigned to this wave from the Parallelism Schedule. - 2. Delegate one subagent **per step** using the agent specified in the step's `Agent:` field β€” launch all subagents for the current wave **in parallel**. Each subagent receives: the full plan, only its assigned step number(s), and the goal. - 3. Each subagent writes its results to `/memories/session/changes-step-{step_number}.md`. - 4. Wait for **all** subagents in the wave to complete. - 5. If any subagent fails β†’ fix the failure (delegate a new subagent of the same agent type for just the failing step) before proceeding. - 6. After the wave succeeds, merge all `changes-step-*.md` from this wave into `/memories/session/changes.md`. - 7. Proceed to the next wave. + **Single-step plans:** Dispatch one subagent and instruct it to write to `/memories/session/changes.md`. - **Single-step plans:** If the Parallelism Schedule has only one wave with one step, delegate a single subagent (using the step's `Agent:` field) with the entire plan. - -12. **Review** β€” Delegate to `Review` with the plan and the list of changed files (from `changes.md`). After Review completes, read `/memories/session/review.md` via #tool:vscode/memory to retrieve the structured review report. If Review finds high-severity issues, delegate back to the appropriate agent to fix, then re-review. +11. **Review** β€” Delegate to `Review` with the plan and the list of per-step changes files (`/memories/session/changes-step-*.md` for multi-step plans, or `/memories/session/changes.md` for single-step plans). After Review completes, read `/memories/session/review.md` via #tool:vscode/memory to retrieve the structured review report. If Review finds high-severity issues, delegate back to the appropriate agent to fix, then re-review. ### Phase 6 β€” Verify & Complete -13. **DocReview** (if any step used `Doc` agent) β€” Delegate to `DocReview` to verify documentation updates. -14. **Complete** β€” Summarize: +12. **DocReview** (if any step used `Doc` agent) β€” Delegate to `DocReview` to verify documentation updates. +13. **Complete** β€” Summarize: - What changed (list of files) - Test results - Review outcome - Any follow-ups or known limitations -Handle `BLOCKED_*` codes per the [memory policy](../instructions/memory-policy.instructions.md). +Handle `BLOCKED_*` codes per the [Memory Policy: BLOCKED Response Codes](../instructions/memory-policy.instructions.md#blocked-response-codes). ## Plan Format @@ -169,61 +151,20 @@ Plans saved to `/memories/session/plan.md` and presented to the user MUST follow **Plan rules:** - NO code blocks in steps β€” describe changes, reference specific symbols/functions -- NO blocking questions in the plan β€” ask them during the Alignment phase via #tool:vscode/askQuestions so all decisions are resolved before the plan is finalized -- The plan MUST be presented inline to the user β€” the plan file is for persistence only, not a substitute for showing it in conversation -- Step-by-step with explicit dependencies β€” mark which steps can run in parallel vs. which block on prior steps +- Each step MUST include `Agent:` and `Files:` fields - Reference critical architecture to reuse β€” specific functions, types, or patterns, not just file names - Explicit scope boundaries β€” what's included and what's deliberately excluded -- **Parallelism independence guarantee** β€” steps marked *parallel* MUST satisfy ALL of: - 1. **Disjoint files** β€” no two parallel steps modify the same file - 2. **Independent compilation** β€” each step's changes compile on their own (`dotnet build` succeeds) - 3. **Independent tests** β€” each step's related tests pass without changes from sibling parallel steps - 4. **No type coupling** β€” a parallel step must not introduce a type, interface, or method that another parallel step consumes +- Acceptance Criteria MUST be concrete and verifiable ## Memory Protocol -> **Goal**: `/memories/session/goal.md` β€” created in Phase 0, read-only afterwards. Provide to every subagent delegation. -> -> **Current plan**: `/memories/session/plan.md` β€” read and write exclusively via #tool:vscode/memory . - -**When to SAVE (write):** -- `/memories/session/goal.md` β€” once, in Phase 0, before Discovery -- After drafting the plan in the Design phase β€” **before** presenting to the user (persistence checkpoint) -- After the user requests changes β€” update the file to keep it in sync with the presented plan -- After approval, if the file doesn't match the approved version -- Whenever plan scope changes during execution - -**When to READ (verify):** -- Before delegating to any subagent after the initial Explore β€” confirm the plan exists and is current -- Before starting the Execute phase β€” confirm the saved plan matches the approved version -- After every save β€” read back to verify content is complete and matches intent -- After delegating to `PlanReview` β€” read `/memories/session/plan-review.md` to retrieve review findings - -**When to BLOCK:** -- If memory write or verification fails β†’ `BLOCKED_NO_PLAN_MEMORY_WRITE` -- If a subagent returns `BLOCKED_NEEDS_PARENT_PLAN` β†’ re-save/verify plan, then re-dispatch -- If a subagent returns `BLOCKED_NEEDS_PARENT_DECISION` β†’ resolve at parent level, update plan, re-dispatch - -## Boundaries - -- βœ… **Always:** - - Save `/memories/session/goal.md` before any research or delegation - - Delegate to `Explore` before drafting any plan - - Use #tool:vscode/askQuestions during Alignment to resolve ambiguities **before** finalizing the plan - - Save plan to memory immediately after drafting, before presenting to user - - Delegate to `PlanReview` after saving the draft plan, before presenting to user - - Wait for explicit user approval before execution - - Verify plan in memory before delegating to any post-Explore subagent - - Re-save plan to memory whenever scope changes - - Delegate to `Review` after every execution round - - Follow conventions from `AGENTS.md` and instruction files - - Use #tool:todo to track progress across phases - - Present the plan inline β€” never rely on the plan file as a substitute - -- 🚫 **Never:** - - Implement code directly β€” always delegate to the appropriate subagent (`Implement`, `Spec`, `Doc`, or `DevOps`) - - Skip the approval gate β€” never implement without user confirmation - - Skip the review phase β€” always delegate to `Review` after implementation - - Put blocking questions in the plan β€” ask during Alignment, not at the end - - Make large assumptions β€” use #tool:vscode/askQuestions when in doubt - - Modify secrets, CI/CD configs, or NuGet publishing settings +- `/memories/session/goal.md` β€” written once in Phase 0, read-only afterwards; provide to every subagent. +- `/memories/session/plan.md` β€” all reads/writes exclusively via #tool:vscode/memory. +- Full protocol, single-writer rules, and `BLOCKED_*` codes: see [`memory-policy.instructions.md`](../instructions/memory-policy.instructions.md). + +## Boundaries (red lines) + +- 🚫 Never implement code or edit source files directly β€” always delegate. +- 🚫 Never skip the user approval gate or the post-implementation `Review`. +- 🚫 Never modify secrets, CI/CD configs, or NuGet publishing settings. +- 🚫 Never read or write `/memories/session/*` with any tool other than #tool:vscode/memory (not even via a URI from #tool:vscode/resolveMemoryFileUri). diff --git a/.github/agents/PlanReview.agent.md b/.github/agents/PlanReview.agent.md index 5af9238..eabb4e0 100644 --- a/.github/agents/PlanReview.agent.md +++ b/.github/agents/PlanReview.agent.md @@ -2,6 +2,7 @@ description: "Use when: verifying a drafted plan against the actual codebase before presenting to user. Checks assumptions, goal achievability, architecture descriptions, and step feasibility." model: GPT-5.4 (copilot) tools: [vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, read, search, web, github/get_file_contents, github/issue_read, codegraphcontext/analyze_code_relationships, codegraphcontext/calculate_cyclomatic_complexity, codegraphcontext/find_code, codegraphcontext/find_dead_code, codegraphcontext/find_most_complex_functions, codegraphcontext/get_repository_stats, codegraphcontext/load_bundle, codegraphcontext/search_registry_bundles, codegraphcontext/visualize_graph_query, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest] +target: vscode agents: [] user-invocable: false argument-hint: "Invoked by Orchestrator after drafting and saving plan to /memories/session/plan.md. No additional input required β€” reads plan automatically." @@ -10,7 +11,7 @@ You are a plan reviewer for the SourceGen C# source generator repository. You ve Follow the project principles in `AGENTS.md`. -Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. +Follow the **Child Workflow** in `.github/instructions/memory-policy.instructions.md`. ## Approach @@ -49,6 +50,7 @@ Follow the **child agent protocol** in `.github/instructions/memory-policy.instr - Run commands or tests - Modify `/memories/session/plan.md` (owned by parent agents) - Suggest scope expansions or architectural improvements β€” only report accuracy/feasibility issues + - Read or write any `/memories/session/*` path with a tool other than #tool:vscode/memory (no #tool:read, #tool:edit, #tool:execute/#tool:run_in_terminal, search/grep tools, or shell commands β€” even via a URI returned by #tool:vscode/resolveMemoryFileUri). See `.github/instructions/memory-policy.instructions.md`. ## Memory Protocol diff --git a/.github/agents/Review.agent.md b/.github/agents/Review.agent.md index 81f831f..9bf994d 100644 --- a/.github/agents/Review.agent.md +++ b/.github/agents/Review.agent.md @@ -2,6 +2,7 @@ description: "Use when: reviewing completed implementation against spec. Performs read-only code review for spec compliance, refactoring opportunities, and performance optimization." model: GPT-5.4 (copilot) tools: [vscode/memory, vscode/resolveMemoryFileUri, execute/getTerminalOutput, read, search, web, github/get_file_contents, github/issue_read, codegraphcontext/analyze_code_relationships, codegraphcontext/calculate_cyclomatic_complexity, codegraphcontext/find_code, codegraphcontext/find_dead_code, codegraphcontext/find_most_complex_functions, codegraphcontext/get_repository_stats, codegraphcontext/load_bundle, codegraphcontext/search_registry_bundles, codegraphcontext/visualize_graph_query, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo, github.vscode-pull-request-github/doSearch, github.vscode-pull-request-github/activePullRequest] +target: vscode agents: [] user-invocable: false argument-hint: "Provide the spec/plan and list of changed files to review" @@ -10,7 +11,7 @@ You are a senior code reviewer specializing in C# source generators. You perform Follow the project principles in `AGENTS.md`. -Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. +Follow the **Child Workflow** in `.github/instructions/memory-policy.instructions.md`. ## Approach 1. **Load goal and plan from memory (MANDATORY FIRST ACTION β€” do this before anything else)**: @@ -18,11 +19,12 @@ Follow the **child agent protocol** in `.github/instructions/memory-policy.instr - If both are present and non-empty β†’ proceed to step 2. - If either is missing or empty β†’ STOP and return `BLOCKED_NEEDS_PARENT_PLAN`. - If memory tool fails β†’ STOP and return `BLOCKED_NO_PLAN_MEMORY`. -2. Read all changed/created files listed in the prompt -3. For each file, compare the implementation against the spec -4. Identify refactoring opportunities and performance concerns -5. Produce a structured review report -6. Use #tool:vscode/memory to save the review report to `/memories/session/review.md` so the parent agent can read it and decide next steps +2. Read the per-step changes files provided by the parent (`/memories/session/changes-step-*.md` for multi-step plans, or `/memories/session/changes.md` for single-step plans) via #tool:vscode/memory to identify changed/created files +3. Read all changed/created files +4. For each file, compare the implementation against the spec +5. Identify refactoring opportunities and performance concerns +6. Produce a structured review report +7. Use #tool:vscode/memory to save the review report to `/memories/session/review.md` so the parent agent can read it and decide next steps ## Review Checklist - **Spec Compliance**: Does the implementation match every requirement in the approved plan? @@ -47,6 +49,7 @@ Follow the **child agent protocol** in `.github/instructions/memory-policy.instr - Run commands or tests - Suggest changes outside the scope of the spec/plan - Modify `/memories/session/plan.md` (owned by parent agents) + - Read or write any `/memories/session/*` path with a tool other than #tool:vscode/memory (no #tool:read, #tool:edit, #tool:execute/#tool:run_in_terminal, search/grep tools, or shell commands β€” even via a URI returned by #tool:vscode/resolveMemoryFileUri). See `.github/instructions/memory-policy.instructions.md`. ## Output Format Return a structured report in this exact format: diff --git a/.github/agents/Spec.agent.md b/.github/agents/Spec.agent.md index f5e7715..ed2fd84 100644 --- a/.github/agents/Spec.agent.md +++ b/.github/agents/Spec.agent.md @@ -1,7 +1,8 @@ --- description: "Use when: updating or creating specification documents (file under Spec/). Writes clear specs targeting both human developers and AI agents." -model: Claude Opus 4.6 (copilot) +model: Claude Sonnet 4.6 (copilot) tools: [vscode/memory, vscode/resolveMemoryFileUri, read, edit, search, web, codegraphcontext/analyze_code_relationships, codegraphcontext/find_code, codegraphcontext/get_repository_stats, 'io.github.upstash/context7/*', 'microsoftdocs/mcp/*', todo] +target: vscode agents: [] user-invocable: false argument-hint: "Implement spec updates from the approved plan stored in /memories/session/plan.md" @@ -10,7 +11,7 @@ You are a specification writer for the SourceGen C# source generator project. Yo Follow the project principles in `AGENTS.md`. -Follow the **child agent protocol** in `.github/instructions/memory-policy.instructions.md`. +Follow the **Child Workflow** in `.github/instructions/memory-policy.instructions.md`. ## Writing Guidelines @@ -84,6 +85,7 @@ RFC 8174 clarifies RFC 2119: - Add content beyond what the plan specifies - Remove existing spec content unless the plan explicitly requires it - Guess at behavior when the plan is ambiguous + - Read or write any `/memories/session/*` path with a tool other than #tool:vscode/memory (no #tool:read, #tool:edit, #tool:execute/#tool:run_in_terminal, search/grep tools, or shell commands β€” even via a URI returned by #tool:vscode/resolveMemoryFileUri). See `.github/instructions/memory-policy.instructions.md`. ## Output Format diff --git a/.github/context/ioc-architecture.context.md b/.github/context/ioc-architecture.context.md deleted file mode 100644 index d93219e..0000000 --- a/.github/context/ioc-architecture.context.md +++ /dev/null @@ -1,72 +0,0 @@ ---- -description: "IoC source generator architecture summary for quick agent context loading." -applyTo: "src/Ioc/**" ---- - -# IoC Generator Architecture - -## Overview - -Compile-time IoC container generator based on `Microsoft.Extensions.DependencyInjection.Abstractions`. Produces zero-reflection, AOT-compatible dependency injection containers via C# incremental source generation. - -## Entry Point - -`IocSourceGenerator` β€” implements `IIncrementalGenerator` in `src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/IocSourceGenerator.cs` - -## Pipeline Stages - -```text -ForAttributeWithMetadataName -β”œβ”€β”€ IocRegisterAttribute β†’ BasicRegistrationResult -β”œβ”€β”€ IocRegisterForAttribute β†’ BasicRegistrationResult -β”œβ”€β”€ IocRegisterDefaultsAttributeβ†’ DefaultSettingsResult -β”œβ”€β”€ IocImportModuleAttribute β†’ ImportModuleResult -β”œβ”€β”€ IocDiscoverAttribute β†’ ClosedGenericDependency -└── IocContainerAttribute β†’ Container generation output - ↓ - CombineAndResolve β†’ ServiceRegistrationWithTags - ↓ - RegisterSourceOutput (registration code + container code) -``` - -## Key Data Models - -| Type | Purpose | Location | -| ---- | ------- | -------- | -| `BasicRegistrationResult` | Core registration pipeline output | `Models/BasicRegistrationResult.cs` | -| `DefaultSettingsResult` | Defaults transform result | `Models/DefaultSettingsResult.cs` | -| `ImportModuleResult` | Module import transform result | `Models/ImportModuleResult.cs` | -| `ServiceRegistrationWithTags` | Merged registration with tags | `Models/ServiceRegistrationWithTags.cs` | -| `OpenGenericEntry` | Open generic service entries | `Models/BasicRegistrationResult.cs` | -| `ClosedGenericDependency` | Discovered closed generic deps | `Models/BasicRegistrationResult.cs` | - -## Public Attributes - -| Attribute | Purpose | -| --------- | ------- | -| `[IocContainer]` | Marks a partial class as the container | -| `[IocRegister]` | Register a service type directly | -| `[IocRegisterFor]` | Register implementation for specific service types | -| `[IocRegisterDefaults]` | Set default lifetime and options for registrations | -| `[IocImportModule]` | Import registrations from another assembly/module | -| `[IocDiscover]` | Auto-discover and register closed generics | -| `[IocInject]` | Mark field/property/method for injection | -| `[IocGenericFactory]` | Mark a method as generic factory provider | - -## Project Layout - -| Directory | Purpose | -| --------- | ------- | -| `src/SourceGen.Ioc/` | Public API attributes and runtime types | -| `src/SourceGen.Ioc.Cli/` | CLI tool for container visualization | -| `src/SourceGen.Ioc.SourceGenerator/Generator/` | Incremental generator pipeline | -| `src/SourceGen.Ioc.SourceGenerator/Analyzer/` | Roslyn analyzers (SGIOC diagnostics) | -| `src/SourceGen.Ioc.SourceGenerator/Models/` | Immutable data models for pipeline | -| `test/SourceGen.Ioc.Test/` | TUnit tests | -| `test/SourceGen.Ioc.TestAot/` | Native AOT validation tests | -| `test/SourceGen.Ioc.TestCase/` | Shared test case projects | - -## Specifications - -- Generator: `src/SourceGen.Ioc.SourceGenerator/Generator/Spec/SPEC.spec.md` -- Analyzer: `src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md` diff --git a/.github/context/project-map.context.md b/.github/context/project-map.context.md deleted file mode 100644 index 0a07d65..0000000 --- a/.github/context/project-map.context.md +++ /dev/null @@ -1,65 +0,0 @@ ---- -description: "Quick project directory index for codebase navigation." ---- - -# Project Map - -## Solution: SourceGen.slnx - -```text -SourceGen/ -β”œβ”€β”€ .github/ -β”‚ β”œβ”€β”€ agents/ # Agent definitions (*.agent.md) -β”‚ β”œβ”€β”€ context/ # Context helpers (*.context.md) -β”‚ β”œβ”€β”€ instructions/ # Domain instructions (*.instructions.md) -β”‚ β”œβ”€β”€ prompts/ # Reusable workflows (*.prompt.md) -β”‚ β”œβ”€β”€ skills/ # Custom skills (SKILL.md) -β”‚ └── workflows/ # GitHub Actions CI/CD -β”œβ”€β”€ docs/ # User-facing documentation (VitePress) -β”‚ └── Ioc/ # IoC feature documentation -β”œβ”€β”€ samples/ # Sample projects for end users -β”‚ └── Ioc/ # IoC sample applications -└── src/ - β”œβ”€β”€ Ioc/ # IoC domain (see ioc-architecture.context.md) - β”‚ β”œβ”€β”€ src/ - β”‚ β”‚ β”œβ”€β”€ SourceGen.Ioc/ # Public API (attributes) - β”‚ β”‚ β”œβ”€β”€ SourceGen.Ioc.Cli/ # CLI visualization tool - β”‚ β”‚ └── SourceGen.Ioc.SourceGenerator/ # Generator + Analyzer - β”‚ └── test/ - β”‚ β”œβ”€β”€ SourceGen.Ioc.Test/ # TUnit tests - β”‚ β”œβ”€β”€ SourceGen.Ioc.TestAot/ # AOT validation - β”‚ β”œβ”€β”€ SourceGen.Ioc.TestCase/ # Shared test cases - β”‚ β”œβ”€β”€ SourceGen.Ioc.Cli.Test/ # CLI tests - β”‚ └── SourceGen.Ioc.Benchmark/ # BenchmarkDotNet - └── Validation/ # Validation domain (future) - └── src/ - β”œβ”€β”€ SourceGen.Validation/ # Public API - └── SourceGen.Validation.SourceGenerator/ # Generator -``` - -## Key Commands - -```powershell -# Build everything -dotnet build SourceGen.slnx - -# IoC tests (TUnit β€” use dotnet run, NOT dotnet test) -dotnet run --project src/Ioc/test/SourceGen.Ioc.Test/SourceGen.Ioc.Test.csproj - -# IoC AOT tests -dotnet publish src/Ioc/test/SourceGen.Ioc.TestAot/SourceGen.Ioc.TestAot.csproj -c Release -dotnet run --project src/Ioc/test/SourceGen.Ioc.TestAot/SourceGen.Ioc.TestAot.csproj -c Release -``` - -## Convention Quick Reference - -| Aspect | Convention | -| ------ | ---------- | -| Language | C# 14 | -| Namespaces | File-scoped | -| Nullability | `#nullable enable` always | -| Data models | `readonly record struct` or `sealed record class` | -| Generated code | `// ` header | -| Test framework | TUnit (`dotnet run`, never `dotnet test --filter`) | -| Diagnostic prefix | `SGIOC` | -| Versioning | nbgv with project-specific tag prefixes | diff --git a/.github/instructions/memory-policy.instructions.md b/.github/instructions/memory-policy.instructions.md index 7a26e68..602bac3 100644 --- a/.github/instructions/memory-policy.instructions.md +++ b/.github/instructions/memory-policy.instructions.md @@ -1,104 +1,106 @@ --- -description: "Use when executing agent workflows that coordinate through /memories/session/. Defines the memory protocol for parent and child agents." +description: "Memory protocol for agent workflows that store goal and plan state in /memories/session/." +applyTo: "**" --- # Memory Policy -All agents that participate in the planβ†’approveβ†’implementβ†’review workflow MUST follow this protocol for `/memories/session/` paths. +## Tool Usage -## Session Memory Paths +For `/memories/session/*`: -| Path | Owner | Purpose | -|------|-------|---------| -| `/memories/session/goal.md` | Parent agents (Orchestrator, Doc, DevOps) | Requirement goal β€” created before Discovery, read by all subagents | -| `/memories/session/plan.md` | Parent agents (Orchestrator, Doc, DevOps) | Approved plan β€” read by all child agents | -| `/memories/session/plan-review.md` | PlanReview agent | Structured plan review report β€” read by parent agent before presenting plan | -| `/memories/session/changes.md` | Implement agent | Changed files, decisions, issues, concerns from implementation | -| `/memories/session/review.md` | Review agent | Structured review report | +- **Read / verify**: `memory` with `view`. +- **Write**: `memory` with `create` / `str_replace` / `insert`. +- **Forbidden**: `read_file`, `grep_search`, `run_in_terminal`, `cat`, shell redirection, pipes. Same rule applies to subagents. +- `resolveMemoryFileUri` is informational only β€” never pass its URI to another tool. -## Memory Access Rules +## Memory Rules -- **ONLY** use #tool:vscode/memory (the `memory` tool) to read and write `/memories/session/` paths. -- Use #tool:vscode/resolveMemoryFileUri to resolve a `/memories/` path to a file URI when another tool requires a real file path instead of a memory abstraction. This is read-only and does not replace #tool:vscode/memory for content operations. -- Do NOT use #tool:read for `/memories/session/` paths; these are memory-only. -- Do NOT use #tool:edit for `/memories/session/` paths; these are memory-only. +Three core principles: -### Exact Tool Call Syntax +1. **Single writer** β€” each file has exactly one owner (see table). Non-owners may only read. +2. **Single command per operation** β€” one `memory` call per intent. Do not combine `create` with `str_replace` / `insert` on the same file in a single turn. A write followed by a verifying `view` is allowed. +3. **Memory tool only** β€” see the rule above. Each `memory.view` call returns the latest committed content, so concurrent readers always see the most recent verified write. -**Reading the plan** β€” use the `view` command: -``` -memory({ command: "view", path: "/memories/session/plan.md" }) -``` +Allowed commands: `view`, `create`, `str_replace`, `insert`. Parents MUST NOT dispatch two writers for the same file in parallel. -**Saving the plan** β€” use the `create` command (for new) or `str_replace` command (for updates): -``` -memory({ command: "create", path: "/memories/session/plan.md", file_text: "" }) -``` +### Session Files And Ownership -``` -memory({ command: "str_replace", path: "/memories/session/plan.md", old_str: "", new_str: "" }) -``` +| Path | Writer | Readers | Purpose | +|------|--------|---------|---------| +| `goal.md` | Parent (Orchestrator / Doc / DevOps) | All | Requirement goal | +| `plan.md` | Parent | All | Approved plan | +| `plan-review.md` | PlanReview | Parent | Plan review report | +| `changes.md` | Implement / Doc / DevOps (single-step plans only) | Parent, Review | Implementation summary | +| `changes-step-{n}.md` | Implement / Doc / DevOps (multi-step plans, one writer per step) | Parent, Review | Per-step implementation summary | +| `review.md` | Review | Parent | Code review report | -## Parent Agent Protocol +## Parent Workflow -Parent agents (Orchestrator, DevOps, Doc) create, save, and maintain the goal and plan: +Parents are Orchestrator, DevOps, and Doc. -0. **Capture Goal** β€” Before any research, distill the user's request into a concise goal statement and save it to `/memories/session/goal.md` via #tool:vscode/memory. This file is the single source of truth for *what* we are trying to achieve. Include it (or reference it) when delegating to every subagent. -1. **Explore First** β€” The first subagent call in every task MUST be `Explore` to gather context. Provide the goal from `goal.md` alongside the research question. -2. **Clarify if Needed** β€” After `Explore` returns, resolve any material ambiguity before finalizing the plan. Use #tool:vscode/askQuestions when requirements are incomplete, multiple valid approaches exist, public API or dependency changes are involved, or a user decision is required. Do not ask questions that can be answered from the codebase. -3. **Create Plan** β€” After `Explore` and any necessary clarification, create `plan.md` using the format defined by the active parent agent. The plan MUST be structured, complete, and current, and MUST include the equivalent of: goal/outcome, implementation approach or steps, scope or relevant files, and acceptance criteria or verification. -4. **Save Draft Plan** β€” After drafting the plan, and before delegating to any subagent after the initial `Explore` call, save the current plan to `/memories/session/plan.md` via #tool:vscode/memory or an update command if the file already exists. -5. **Present & Approve** β€” Present the plan inline to the user. The memory file is for persistence, not a substitute for showing the plan in conversation. Do not delegate execution subagents until the user explicitly approves. -6. **Verify Plan** β€” After every save, read back via #tool:vscode/memory and confirm the content is complete and current. -7. **Gate on Failure** β€” If memory write or verification fails, stop and return `BLOCKED_NO_PLAN_MEMORY_WRITE`. If user/system action may resolve the problem, use #tool:vscode/askQuestions to request correction before stopping. -8. **Handle Blocked Subagents** β€” If a subagent returns `BLOCKED_NEEDS_PARENT_PLAN` or `BLOCKED_NEEDS_PARENT_DECISION`, resolve at parent level, re-save/verify plan if needed, then re-dispatch the same subagent. -9. **Re-save on Scope Change** β€” If plan scope changes during refinement or execution, overwrite `/memories/session/plan.md` and verify again before any subsequent subagent delegation. +### Happy Path -## Child Agent Protocol +1. `create` `goal.md` with a concise distilled goal β€” before any research. +2. Dispatch `Explore` as the first subagent call, UNLESS the task is trivially scoped (single-file edit, typo fix, or user-supplied exact instructions). When in doubt, dispatch `Explore`. +3. Call `askQuestions` if **any one** of the following is true (OR, not AND): + - `goal.md` lacks explicit acceptance criteria. + - Two or more mutually exclusive implementation approaches exist and existing rules do not specify priority. + - The change touches a public API, a package dependency, or `.github/workflows/*`. +4. Draft `plan.md` (goal, approach/steps, scope/files, acceptance criteria) β†’ save with `create` (new) or `str_replace` (update) β†’ `view` to verify β†’ present the plan inline β†’ wait for explicit user approval. +5. Dispatch non-Explore subagents only after `plan.md` is verified AND approved. -**CRITICAL**: The VERY FIRST action of any child agent MUST be to load and validate the plan. Do NOT skip this step. Do NOT proceed to any other work until the plan is loaded and confirmed non-empty. +### Error Handling -Child agents (Implement, Review, DocReview, Spec, PlanReview) load and validate the plan: +- Save or verify of `plan.md` fails β†’ return `BLOCKED_NO_PLAN_MEMORY_WRITE`. +- Child returns any `BLOCKED_*` β†’ resolve at parent level β†’ update `plan.md` if scope changed β†’ re-dispatch. +- Scope changes mid-task β†’ overwrite and re-verify `plan.md` before the next dispatch. -1. **Load Goal and Plan (FIRST ACTION β€” mandatory, non-skippable)** β€” Call #tool:vscode/memory to read `/memories/session/goal.md` first, then `/memories/session/plan.md`, as your very first tool calls. No other tool call may precede these. -2. **Validate Content** β€” Confirm both goal and plan content are present and non-empty. If valid, proceed to work. -3. **Block if Missing** β€” If memory read fails or plan is missing/empty, stop immediately and return `BLOCKED_NEEDS_PARENT_PLAN` with a brief reason requesting the parent agent to save a complete plan. Do NOT attempt to guess the plan or proceed without it. -4. **Block on Tool Failure** β€” If memory is unavailable due to tool/runtime issues, stop and return `BLOCKED_NO_PLAN_MEMORY`. -5. **Block on Ambiguity** β€” If anything in the plan is unclear or a design decision is needed, return `BLOCKED_NEEDS_PARENT_DECISION` with the exact clarification needed. -6. **Never Ask User** β€” Never request plan content or approvals directly from the user; all requests go through the parent agent. +## Child Workflow -## BLOCKED Response Codes +Children are Implement, Review, DocReview, Spec, PlanReview. -| Code | Meaning | Who Returns | Who Resolves | -|------|---------|-------------|--------------| -| `BLOCKED_NEEDS_PARENT_PLAN` | Plan missing or empty in memory | Child agent | Parent agent saves plan, re-dispatches | -| `BLOCKED_NEEDS_PARENT_DECISION` | Plan ambiguity or design decision needed | Child agent | Parent agent clarifies, re-dispatches | -| `BLOCKED_NO_PLAN_MEMORY` | Memory tool unavailable | Any agent | User/system resolves tool issue | -| `BLOCKED_NO_PLAN_MEMORY_WRITE` | Memory write or verification failed | Parent agent | User resolves, parent retries | +### Happy Path -## Reporting Guidance +1. Sequentially call: + ``` + memory({ command: "view", path: "/memories/session/goal.md" }) + memory({ command: "view", path: "/memories/session/plan.md" }) + ``` +2. If both files are present and non-empty β†’ perform the assigned task β†’ report using the agent's defined Output Format. -If an agent definition requires a structured completion report, include plan-memory status in that report. A recommended template is: +### Error Handling + +| Condition | Action | +|---|---| +| Either file missing or empty | Return `BLOCKED_NEEDS_PARENT_PLAN` | +| Plan ambiguous or design decision needed | Return `BLOCKED_NEEDS_PARENT_DECISION` with the exact clarification needed | +| `memory` call returned an actual error | Return `BLOCKED_NO_PLAN_MEMORY` with the verbatim error | + +Children MUST NOT request plan content or approval from the user β€” route through the parent. `vscode/memory` is granted in every child's frontmatter `tools:`; you MUST attempt the call before claiming it is unavailable, and MUST NOT substitute another tool. + +## Parent β†’ Child Delegation Template ``` -#### Preconditions -- MemoryGoalLoaded: true | false -- MemoryPlanLoaded: true | false -- MemoryPlanSaved: true | false (parent agents only) -- MemoryPlanVerified: true | false (parent agents only) -- MemoryPath: /memories/session/goal.md, /memories/session/plan.md -- Blocker: (empty or BLOCKED_* code with reason) +Your frontmatter (tools:) grants #tool:vscode/memory. Follow these steps: + +1. First action β€” call: + memory({ command: "view", path: "/memories/session/goal.md" }) +2. Second action β€” call: + memory({ command: "view", path: "/memories/session/plan.md" }) +3. Then perform: +4. Report using: + +Follow the Memory Rules above. If the memory tool genuinely fails, return BLOCKED_NO_PLAN_MEMORY with the verbatim error message. ``` -If the active agent definition does not require a structured preconditions block, at minimum report any `BLOCKED_*` state clearly. +## BLOCKED Response Codes -## Boundaries +| Code | Returned by | Resolved by | +|------|-------------|-------------| +| `BLOCKED_NEEDS_PARENT_PLAN` | Child | Parent | +| `BLOCKED_NEEDS_PARENT_DECISION` | Child | Parent | +| `BLOCKED_NO_PLAN_MEMORY` | Any | User / system | +| `BLOCKED_NO_PLAN_MEMORY_WRITE` | Parent | User | -- βœ… **Always:** Access `/memories/session/` paths exclusively via #tool:vscode/memory -- βœ… **Always:** Use #tool:vscode/resolveMemoryFileUri when another tool needs a file URI for a memory path -- βœ… **Always:** Verify plan content after every save operation -- βœ… **Always:** Handle all `BLOCKED_*` responses at the appropriate level -- 🚫 **Never:** Use #tool:read for `/memories/session/` paths -- 🚫 **Never:** Use #tool:edit for `/memories/session/` paths -- 🚫 **Never:** Delegate to any subagent (after initial Explore) before saving and verifying plan -- 🚫 **Never:** Have child agents ask users directly for plan content or approvals +Surface any `BLOCKED_*` state clearly in the agent's report. diff --git a/.github/instructions/project-structure.instructions.md b/.github/instructions/project-structure.instructions.md index 48b6e00..79ee8f0 100644 --- a/.github/instructions/project-structure.instructions.md +++ b/.github/instructions/project-structure.instructions.md @@ -12,7 +12,6 @@ See also: [AGENTS.md](../../AGENTS.md) for project principles and agent topology | Directory | Purpose | Default Access | |-----------|---------|----------------| | `.github/agents/` | Agent definition files | ⚠️ Ask first | -| `.github/context/` | Context helpers for agent loading | ⚠️ Ask first | | `.github/instructions/` | Shared instruction files for agents and Copilot | ⚠️ Ask first | | `.github/prompts/` | Reusable agentic workflows | ⚠️ Ask first | | `.github/skills/` | Copilot custom skills with scripts | ⚠️ Ask first | diff --git a/.gitignore b/.gitignore index 6b62af1..df7e21e 100644 --- a/.gitignore +++ b/.gitignore @@ -240,6 +240,7 @@ ClientBin/ *.pfx *.publishsettings orleans.codegen.cs +/temp/ # Including strong name files can present a security risk # (https://github.com/github/gitignore/pull/2483#issue-259490424) diff --git a/AGENTS.md b/AGENTS.md index 29782bd..95f5e47 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -18,7 +18,7 @@ C# incremental source generators for compile-time code generation. - All generated code must include `// ` header and `#nullable enable` - All generated code must be deterministic β€” same input produces identical output - Generator data models must be immutable with value-based equality β€” use `sealed record class` or `readonly record struct` -- Use `PolyType.Roslyn` utilities: `SourceWriter`, `ImmutableEquatableArray`, etc. +- Use `PolyType.Roslyn` utilities: `SourceWriter`, `ImmutableEquatableArray`, `ImmutableEquatableDictionary`, `ImmutableEquatableSet`, etc. - Generator pipeline: `ForAttributeWithMetadataName` β†’ `Select` (extract data model) β†’ `Collect` β†’ `RegisterSourceOutput` - Keep diagnostic descriptors in the Analyzer project, not in the Generator - If unsure whether an issue is a design decision or a test failure, **ask the user for clarification** diff --git a/samples/Ioc/IocSample.Shared/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/SharedModule.Container.g.cs b/samples/Ioc/IocSample.Shared/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/SharedModule.Container.g.cs index 738362f..6d356b5 100644 --- a/samples/Ioc/IocSample.Shared/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/SharedModule.Container.g.cs +++ b/samples/Ioc/IocSample.Shared/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/SharedModule.Container.g.cs @@ -35,14 +35,14 @@ public SharedModule(IServiceProvider? fallbackProvider) _fallbackProvider = fallbackProvider; // Initialize eager singletons - _iocSample_Shared_TestHandler = GetIocSample_Shared_TestHandler(); - _iocSample_Shared_TestRequest2Handler = GetIocSample_Shared_TestRequest2Handler(); - _iocSample_Shared_TestRequest3Handler = GetIocSample_Shared_TestRequest3Handler(); - _iocSample_Shared_Logger_IocSample_Shared_HandlerDecorator1_IocSample_Shared_TestRequest__System_Collections_Generic_List_string___ = GetIocSample_Shared_Logger_IocSample_Shared_HandlerDecorator1_IocSample_Shared_TestRequest__System_Collections_Generic_List_string___(); - _iocSample_Shared_Logger_IocSample_Shared_TestRequest2Handler_ = GetIocSample_Shared_Logger_IocSample_Shared_TestRequest2Handler_(); - _iocSample_Shared_Logger_IocSample_Shared_HandlerDecorator1_IocSample_Shared_TestRequest2__System_Collections_Generic_List_string___ = GetIocSample_Shared_Logger_IocSample_Shared_HandlerDecorator1_IocSample_Shared_TestRequest2__System_Collections_Generic_List_string___(); - _iocSample_Shared_Logger_IocSample_Shared_TestRequest3Handler_ = GetIocSample_Shared_Logger_IocSample_Shared_TestRequest3Handler_(); - _iocSample_Shared_Logger_IocSample_Shared_HandlerDecorator1_IocSample_Shared_TestRequest3__int__ = GetIocSample_Shared_Logger_IocSample_Shared_HandlerDecorator1_IocSample_Shared_TestRequest3__int__(); + GetIocSample_Shared_TestHandler(); + GetIocSample_Shared_TestRequest2Handler(); + GetIocSample_Shared_TestRequest3Handler(); + GetIocSample_Shared_Logger_IocSample_Shared_HandlerDecorator1_IocSample_Shared_TestRequest__System_Collections_Generic_List_string___(); + GetIocSample_Shared_Logger_IocSample_Shared_TestRequest2Handler_(); + GetIocSample_Shared_Logger_IocSample_Shared_HandlerDecorator1_IocSample_Shared_TestRequest2__System_Collections_Generic_List_string___(); + GetIocSample_Shared_Logger_IocSample_Shared_TestRequest3Handler_(); + GetIocSample_Shared_Logger_IocSample_Shared_HandlerDecorator1_IocSample_Shared_TestRequest3__int__(); } private SharedModule(SharedModule parent) diff --git a/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/Module.Container.g.cs b/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/Module.Container.g.cs index 9d1262a..ae0952b 100644 --- a/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/Module.Container.g.cs +++ b/samples/Ioc/IocSample/Generated/SourceGen.Ioc.SourceGenerator/SourceGen.Ioc.IocSourceGenerator/Module.Container.g.cs @@ -38,16 +38,16 @@ public Module(IServiceProvider? fallbackProvider) _iocSample_Shared_SharedModule = new global::IocSample.Shared.SharedModule(fallbackProvider); // Initialize eager singletons - _iocSample_TestQueryHandler = GetIocSample_TestQueryHandler(); - _iocSample_Consumer = GetIocSample_Consumer(); - _iocSample_External = GetIocSample_External(); - _iocSample_GenericRequestHandler_IocSample_Entity_ = GetIocSample_GenericRequestHandler_IocSample_Entity_(); - _iocSample_GenericRequestHandler_IocSample_Entity3_ = GetIocSample_GenericRequestHandler_IocSample_Entity3_(); - _iocSample_GenericRequestHandler_IocSample_Entity2_ = GetIocSample_GenericRequestHandler_IocSample_Entity2_(); - _iocSample_GenericRequestHandler2_IocSample_Entity3_ = GetIocSample_GenericRequestHandler2_IocSample_Entity3_(); - _iocSample_GenericRequestHandler2_IocSample_Entity_ = GetIocSample_GenericRequestHandler2_IocSample_Entity_(); - _iocSample_IGenericFactoryService_IocSample_IWrapper_decimal___IocSample_GenericFactory_Create = GetIocSample_IGenericFactoryService_IocSample_IWrapper_decimal___IocSample_GenericFactory_Create(); - _iocSample_GenericRequestHandler2_IocSample_Entity2_ = GetIocSample_GenericRequestHandler2_IocSample_Entity2_(); + GetIocSample_TestQueryHandler(); + GetIocSample_Consumer(); + GetIocSample_External(); + GetIocSample_GenericRequestHandler_IocSample_Entity_(); + GetIocSample_GenericRequestHandler_IocSample_Entity3_(); + GetIocSample_GenericRequestHandler_IocSample_Entity2_(); + GetIocSample_GenericRequestHandler2_IocSample_Entity3_(); + GetIocSample_GenericRequestHandler2_IocSample_Entity_(); + GetIocSample_IGenericFactoryService_IocSample_IWrapper_decimal___IocSample_GenericFactory_Create(); + GetIocSample_GenericRequestHandler2_IocSample_Entity2_(); // Initialize Lazy wrapper fields _lazy_IocSample_IWrapperService_int__IocSample_WrapperService_int_ = new global::System.Lazy>(() => GetIocSample_WrapperService_int_(), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication); diff --git a/src/Ioc/AGENTS.md b/src/Ioc/AGENTS.md index 9562a0e..ca98aff 100644 --- a/src/Ioc/AGENTS.md +++ b/src/Ioc/AGENTS.md @@ -36,7 +36,7 @@ dotnet publish src/Ioc/test/SourceGen.Ioc.TestAot/SourceGen.Ioc.TestAot.csproj - | Spec | Scope | | --- | --- | -| [Generator SPEC](src/SourceGen.Ioc.SourceGenerator/Generator/Spec/SPEC.spec.md) | Index β€” registration, container generation, all sub-specs | +| [Generator SPEC](src/SourceGen.Ioc.SourceGenerator/Spec/SPEC.spec.md) | Index β€” registration, container generation, all sub-specs | | [Analyzer SPEC](src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md) | Diagnostic rules and analyzers | ## Domain Rules diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/AnalyzerHelpers.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/AnalyzerHelpers.cs index a974aa4..a6ad4c0 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/AnalyzerHelpers.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/AnalyzerHelpers.cs @@ -162,29 +162,6 @@ public static bool IsAnyIoCAttribute(INamedTypeSymbol attributeClass, IocAttribu || IsIoCRegisterDefaultsAttribute(attributeClass, attributeSymbols); } - /// - /// Extracts registered service types from an IoC registration attribute's ServiceTypes property. - /// - /// The attribute to extract service types from. - /// An enumerable of service type symbols. - public static IEnumerable GetServiceTypesFromAttribute(AttributeData attribute) - { - foreach(var namedArg in attribute.NamedArguments) - { - if(namedArg.Key is not "ServiceTypes") - continue; - - if(namedArg.Value.Kind is not TypedConstantKind.Array) - continue; - - foreach(var element in namedArg.Value.Values) - { - if(element.Value is INamedTypeSymbol serviceType) - yield return serviceType; - } - } - } - /// /// Returns when the implementation contains at least one async inject method. /// Mirrors the generator's async-init classification by looking for instance ordinary methods @@ -272,8 +249,7 @@ public int GetHashCode((INamedTypeSymbol ServiceType, string? Key) obj) /// public static IEnumerable EnumerateRegisteredServiceTypes( INamedTypeSymbol implementationType, - AttributeData attribute, - IocAttributeSymbols attributeSymbols) + AttributeData attribute) { yield return implementationType; @@ -281,16 +257,10 @@ public static IEnumerable EnumerateRegisteredServiceTypes( if(attrClass is null) yield break; - if(IsIoCRegisterAttribute(attrClass, attributeSymbols) && attrClass.IsGenericType) - { - foreach(var typeArg in attrClass.TypeArguments) - { - if(typeArg is INamedTypeSymbol serviceType) - yield return serviceType; - } - } + foreach(var serviceType in attribute.GetServiceTypeSymbolsFromGenericAttribute()) + yield return serviceType; - foreach(var serviceType in GetServiceTypesFromAttribute(attribute)) + foreach(var serviceType in attribute.GetServiceTypeSymbols()) yield return serviceType; var (_, registerAllInterfaces) = attribute.TryGetRegisterAllInterfaces(); @@ -392,73 +362,14 @@ public static bool IsNonGenericTaskType(ITypeSymbol? type) { type = UnwrapNullableValueType(type); - if(type is not INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: 1 } namedType) + if(type is null + || !type.TryGetWrapperInfo(out var wrapperInfo)) return null; - if(namedType.ContainingNamespace.ToDisplayString() != "System.Threading.Tasks") + if(wrapperInfo.Kind is not WrapperKind.Task and not WrapperKind.ValueTask) return null; - if(namedType.Name is not ("Task" or "ValueTask")) - return null; - - return namedType.TypeArguments[0] as INamedTypeSymbol; - } - - /// - /// Unwraps any Generator-supported wrapper type to extract the inner service type for partial accessor resolution. - /// Supported wrappers: Task<T>, Lazy<T>, Func<T>, IEnumerable<T>, IReadOnlyCollection<T>, ICollection<T>, - /// IReadOnlyList<T>, IList<T>, T[], IDictionary<K,V>, IReadOnlyDictionary<K,V>, Dictionary<K,V>, KeyValuePair<K,V>. - /// - /// The inner service type, or null if the type is not a recognized wrapper. - public static INamedTypeSymbol? TryUnwrapWrapperElementType(ITypeSymbol type) - { - // Array: T[] - if(type.TypeKind == TypeKind.Array) - return (type as IArrayTypeSymbol)?.ElementType as INamedTypeSymbol; - - if(type is not INamedTypeSymbol named) - return null; - - // Arity-1 wrappers - if(named.Arity == 1) - { - var typeArg = named.TypeArguments[0] as INamedTypeSymbol; - - // IEnumerable - if(named.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T) - return typeArg; - - var ns = named.ContainingNamespace.ToDisplayString(); - - if(ns == "System.Collections.Generic" - && named.Name is "IReadOnlyCollection" or "ICollection" or "IReadOnlyList" or "IList") - return typeArg; - - if(ns == "System" && named.Name == "Lazy") - return typeArg; - - if(ns == "System.Threading.Tasks" && named.Name is "Task" or "ValueTask") - return typeArg; - } - - // Arity-2 wrappers β€” return the value type (TypeArguments[1]) - if(named.Arity == 2) - { - var ns = named.ContainingNamespace.ToDisplayString(); - - if(ns == "System.Collections.Generic" - && named.Name is "IDictionary" or "IReadOnlyDictionary" or "Dictionary" or "KeyValuePair") - return named.TypeArguments[1] as INamedTypeSymbol; - } - - // Func, Func, Func, etc. - // The last type argument is always the return type (the service type to resolve). - if(named.Arity >= 1 - && named.ContainingNamespace.ToDisplayString() == "System" - && named.Name == "Func") - return named.TypeArguments[named.TypeArguments.Length - 1] as INamedTypeSymbol; - - return null; + return wrapperInfo.ElementType; } /// @@ -478,13 +389,6 @@ public static bool IsUnsupportedPartialAccessorReturnType(ITypeSymbol type) return named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; } - /// - /// Returns true when the type is ValueTask<T> (generic, arity 1). - /// - public static bool IsGenericValueTaskType(ITypeSymbol type) - => type is INamedTypeSymbol { Name: "ValueTask", Arity: 1 } named - && named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; - /// /// Returns when is a direct Task<T> /// for the specified service type. @@ -714,45 +618,3 @@ public static (bool IsInvalid, string? Reason) GetRegistrationInvalidReason(INam } } -/// -/// Holds cached IoC attribute type symbols for efficient comparison in analyzers. -/// -internal sealed class IocAttributeSymbols -{ - public INamedTypeSymbol? IocContainerAttribute { get; } - public INamedTypeSymbol? IocRegisterAttribute { get; } - public INamedTypeSymbol? IocRegisterAttribute_T1 { get; } - public INamedTypeSymbol? IocRegisterForAttribute { get; } - public INamedTypeSymbol? IocRegisterForAttribute_T1 { get; } - public INamedTypeSymbol? IocRegisterDefaultsAttribute { get; } - public INamedTypeSymbol? IocRegisterDefaultsAttribute_T1 { get; } - public INamedTypeSymbol? IocImportModuleAttribute { get; } - public INamedTypeSymbol? IocImportModuleAttribute_T1 { get; } - - public IocAttributeSymbols(Compilation compilation) - { - IocContainerAttribute = compilation.GetTypeByMetadataName(Constants.IocContainerAttributeFullName); - IocRegisterAttribute = compilation.GetTypeByMetadataName(Constants.IocRegisterAttributeFullName); - IocRegisterAttribute_T1 = compilation.GetTypeByMetadataName(Constants.IocRegisterAttributeFullName_T1); - IocRegisterForAttribute = compilation.GetTypeByMetadataName(Constants.IocRegisterForAttributeFullName); - IocRegisterForAttribute_T1 = compilation.GetTypeByMetadataName(Constants.IocRegisterForAttributeFullName_T1); - IocRegisterDefaultsAttribute = compilation.GetTypeByMetadataName(Constants.IocRegisterDefaultsAttributeFullName); - IocRegisterDefaultsAttribute_T1 = compilation.GetTypeByMetadataName(Constants.IocRegisterDefaultsAttributeFullName_T1); - IocImportModuleAttribute = compilation.GetTypeByMetadataName(Constants.IocImportModuleAttributeFullName); - IocImportModuleAttribute_T1 = compilation.GetTypeByMetadataName(Constants.IocImportModuleAttributeFullName_T1); - } - - /// - /// Checks if any IoC registration attribute is available in the compilation. - /// - public bool HasAnyRegistrationAttribute => - IocRegisterAttribute is not null - || IocRegisterAttribute_T1 is not null - || IocRegisterForAttribute is not null - || IocRegisterForAttribute_T1 is not null; - - /// - /// Checks if the IocContainerAttribute is available. - /// - public bool HasContainerAttribute => IocContainerAttribute is not null; -} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/ContainerAnalyzer.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/ContainerAnalyzer.cs index d59301a..9503d9e 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/ContainerAnalyzer.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/ContainerAnalyzer.cs @@ -519,6 +519,8 @@ when property.IsPartialDefinition continue; var serviceType = GetAccessorServiceType(normalizedReturnType); + var isGenericValueTaskReturnType = normalizedReturnType.TryGetWrapperInfo(out var returnWrapperInfo) + && returnWrapperInfo.Kind is WrapperKind.ValueTask; // Guard: if the innermost type (ignoring downgrade rules) is an async-init service, // all diagnostic reporting is owned by AnalyzeAsyncPartialAccessors (SGIOC027/029). @@ -550,7 +552,7 @@ when property.IsPartialDefinition var location = member.Locations.FirstOrDefault(); // For ValueTask (not a generator-supported recursive wrapper), report the full return type // because the shape is considered downgraded/unsupported per spec. - var diagnosticType = AnalyzerHelpers.IsGenericValueTaskType(normalizedReturnType) + var diagnosticType = isGenericValueTaskReturnType ? normalizedReturnType : serviceType; context.ReportDiagnostic(Diagnostic.Create( @@ -566,7 +568,7 @@ when property.IsPartialDefinition if(serviceType is not null && !isNullable && IsPartialAccessorServiceRegistered(serviceType, serviceKey, analyzerContext) - && AnalyzerHelpers.IsGenericValueTaskType(normalizedReturnType)) + && isGenericValueTaskReturnType) { var location = member.Locations.FirstOrDefault(); context.ReportDiagnostic(Diagnostic.Create( @@ -589,8 +591,7 @@ private static void RegisterServiceTypes( foreach(var serviceType in AnalyzerHelpers.EnumerateRegisteredServiceTypes( implementationType, - attribute, - analyzerContext.AttributeSymbols)) + attribute)) { analyzerContext.RegisteredServiceTypes.TryAdd(serviceType, true); analyzerContext.Registrations.Add(new ServiceRegistration(serviceType, serviceKey, implementationType)); @@ -612,57 +613,33 @@ private static void RegisterServiceTypes( // ValueTask is not a generator-supported recursive wrapper; return T directly so callers // can check registration and async-init. SGIOC029 / SGIOC021 are reported separately. - if(AnalyzerHelpers.IsGenericValueTaskType(normalizedReturnType)) - return AnalyzerHelpers.TryUnwrapWrapperElementType(normalizedReturnType); + if(normalizedReturnType.TryGetWrapperInfo(out var rootWrapperInfo) + && rootWrapperInfo.Kind is WrapperKind.ValueTask) + return rootWrapperInfo.ElementType; // Recursively unwrap generator-supported wrappers with downgrade detection. // Mirrors TransformExtensions.cs downgrade rules (nested Task shapes, collection-at-top). ITypeSymbol current = normalizedReturnType; - var isFirst = true; + var isAfterCollection = false; - while(AnalyzerHelpers.TryUnwrapWrapperElementType(current) is { } element) + while(current.TryGetWrapperInfo(out var wrapperInfo) + && wrapperInfo.ElementType is { } element) { - // Downgrade rule 1: Task β€” outer is Task AND inner type is itself a wrapper - if(IsGenericTask(current) && AnalyzerHelpers.TryUnwrapWrapperElementType(element) is not null) - return null; + var elementKind = element.TryGetWrapperInfo(out var elementWrapperInfo) + ? elementWrapperInfo.Kind + : WrapperKind.None; - // Downgrade rule 2: Wrapper β€” outer is non-Task wrapper AND inner type is Task - if(!IsGenericTask(current) && IsGenericTask(element)) + if(elementKind is not WrapperKind.None + && IsUnsupportedWrapperNesting(wrapperInfo.Kind, elementKind, isAfterCollection)) return null; - // Downgrade rule 3: ValueTask encountered during recursion (at top level it is handled above) - if(!isFirst && AnalyzerHelpers.IsGenericValueTaskType(current)) - return null; - - // Downgrade rule 4: Collection-at-top β€” outermost is a collection AND inner is a non-collection wrapper - if(isFirst && IsCollectionWrapper(current) && AnalyzerHelpers.TryUnwrapWrapperElementType(element) is not null && !IsCollectionWrapper(element)) - return null; + if(wrapperInfo.Kind.IsCollectionWrapperKind()) + isAfterCollection = true; current = element; - isFirst = false; } return current as INamedTypeSymbol; - - static bool IsGenericTask(ITypeSymbol type) - => type is INamedTypeSymbol { Name: "Task", Arity: 1 } named - && named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; - - static bool IsCollectionWrapper(ITypeSymbol type) - { - if(type.TypeKind == TypeKind.Array) - return true; - - if(type is not INamedTypeSymbol named || named.Arity != 1) - return false; - - if(named.OriginalDefinition.SpecialType == SpecialType.System_Collections_Generic_IEnumerable_T) - return true; - - var ns = named.ContainingNamespace.ToDisplayString(); - return ns == "System.Collections.Generic" - && named.Name is "IReadOnlyCollection" or "ICollection" or "IReadOnlyList" or "IList"; - } } /// @@ -674,7 +651,8 @@ static bool IsCollectionWrapper(ITypeSymbol type) { ITypeSymbol current = input; - while(AnalyzerHelpers.TryUnwrapWrapperElementType(current) is { } element) + while(current.TryGetWrapperInfo(out var wrapperInfo) + && wrapperInfo.ElementType is { } element) { current = element; } @@ -853,23 +831,10 @@ private static bool IsPartialAccessorServiceRegistered( if(attrClass is null) return null; - // Non-generic form: [IocImportModule(typeof(T))] - if(AnalyzerHelpers.IsAttributeMatch(attrClass, attributeSymbols.IocImportModuleAttribute)) - { - return attr.ConstructorArguments.Length > 0 - ? attr.ConstructorArguments[0].Value as INamedTypeSymbol - : null; - } - - // Generic form: [IocImportModule] β€” OriginalDefinition comparison is handled inside IsAttributeMatch - if(AnalyzerHelpers.IsAttributeMatch(attrClass, attributeSymbols.IocImportModuleAttribute_T1)) - { - return attrClass.IsGenericType && attrClass.TypeArguments.Length > 0 - ? attrClass.TypeArguments[0] as INamedTypeSymbol - : null; - } + if(!AnalyzerHelpers.IsIocImportModuleAttribute(attrClass, attributeSymbols)) + return null; - return null; + return attr.GetImportedModuleType(); } /// diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/IocAttributeSymbols.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/IocAttributeSymbols.cs new file mode 100644 index 0000000..108f038 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/IocAttributeSymbols.cs @@ -0,0 +1,44 @@ +namespace SourceGen.Ioc; + +/// +/// Holds cached IoC attribute type symbols for efficient comparison in analyzers. +/// +internal sealed class IocAttributeSymbols +{ + public INamedTypeSymbol? IocContainerAttribute { get; } + public INamedTypeSymbol? IocRegisterAttribute { get; } + public INamedTypeSymbol? IocRegisterAttribute_T1 { get; } + public INamedTypeSymbol? IocRegisterForAttribute { get; } + public INamedTypeSymbol? IocRegisterForAttribute_T1 { get; } + public INamedTypeSymbol? IocRegisterDefaultsAttribute { get; } + public INamedTypeSymbol? IocRegisterDefaultsAttribute_T1 { get; } + public INamedTypeSymbol? IocImportModuleAttribute { get; } + public INamedTypeSymbol? IocImportModuleAttribute_T1 { get; } + + public IocAttributeSymbols(Compilation compilation) + { + IocContainerAttribute = compilation.GetTypeByMetadataName(Constants.IocContainerAttributeFullName); + IocRegisterAttribute = compilation.GetTypeByMetadataName(Constants.IocRegisterAttributeFullName); + IocRegisterAttribute_T1 = compilation.GetTypeByMetadataName(Constants.IocRegisterAttributeFullName_T1); + IocRegisterForAttribute = compilation.GetTypeByMetadataName(Constants.IocRegisterForAttributeFullName); + IocRegisterForAttribute_T1 = compilation.GetTypeByMetadataName(Constants.IocRegisterForAttributeFullName_T1); + IocRegisterDefaultsAttribute = compilation.GetTypeByMetadataName(Constants.IocRegisterDefaultsAttributeFullName); + IocRegisterDefaultsAttribute_T1 = compilation.GetTypeByMetadataName(Constants.IocRegisterDefaultsAttributeFullName_T1); + IocImportModuleAttribute = compilation.GetTypeByMetadataName(Constants.IocImportModuleAttributeFullName); + IocImportModuleAttribute_T1 = compilation.GetTypeByMetadataName(Constants.IocImportModuleAttributeFullName_T1); + } + + /// + /// Checks if any IoC registration attribute is available in the compilation. + /// + public bool HasAnyRegistrationAttribute => + IocRegisterAttribute is not null + || IocRegisterAttribute_T1 is not null + || IocRegisterForAttribute is not null + || IocRegisterForAttribute_T1 is not null; + + /// + /// Checks if the IocContainerAttribute is available. + /// + public bool HasContainerAttribute => IocContainerAttribute is not null; +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.DependencyAnalysis.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.DependencyAnalysis.cs index 74612b8..7be3a89 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.DependencyAnalysis.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.DependencyAnalysis.cs @@ -1,6 +1,3 @@ -using System.Collections.Concurrent; -using Microsoft.CodeAnalysis.Diagnostics; - namespace SourceGen.Ioc; /// diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.ServiceCollection.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.ServiceCollection.cs index d42df3f..888c312 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.ServiceCollection.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.ServiceCollection.cs @@ -43,7 +43,7 @@ private static ImmutableHashSet CollectAssemblyLevelRegistrations( var (hasFactory, hasInstance) = attribute.HasFactoryOrInstance(); var assemblyServiceTypes = AnalyzerHelpers.EnumerateRegisteredServiceTypes( - targetType, attribute, analyzerContext.AttributeSymbols).ToList(); + targetType, attribute).ToList(); RegisterServiceWithIndex(analyzerContext, targetType, lifetime, location, key, keyTypeSymbol, key is not null, hasFactory, hasInstance, assemblyServiceTypes); } @@ -149,7 +149,7 @@ private static void CollectAndValidateNamedType(SymbolAnalysisContext context, A // Register service with index for faster lookup // Dependency analysis will be done in CompilationEnd after all services are collected var serviceTypes = AnalyzerHelpers.EnumerateRegisteredServiceTypes( - targetType, attribute, analyzerContext.AttributeSymbols).ToList(); + targetType, attribute).ToList(); RegisterServiceWithIndex(analyzerContext, targetType, currentLifetime, location, key, keyTypeSymbol, key is not null, hasFactory, hasInstance, serviceTypes); } } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.UnresolvableMembers.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.UnresolvableMembers.cs index 4ab77e8..f3c6a3c 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.UnresolvableMembers.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.UnresolvableMembers.cs @@ -1,5 +1,3 @@ -using Microsoft.CodeAnalysis.Diagnostics; - namespace SourceGen.Ioc; /// diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.cs index cdcfcb0..c1ecabc 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/RegisterAnalyzer.cs @@ -506,7 +506,7 @@ private static void AnalyzeAllDependencies(CompilationAnalysisContext context, A hasRegistrationAttribute = true; var (serviceKey, _, _) = attribute.GetKeyInfo(); - foreach(var serviceType in AnalyzerHelpers.EnumerateRegisteredServiceTypes(implementationType, attribute, attributeSymbols)) + foreach(var serviceType in AnalyzerHelpers.EnumerateRegisteredServiceTypes(implementationType, attribute)) { serviceTypes.Add((serviceType, serviceKey)); } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md index 83459ce..55dce5b 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Analyzer/Spec/SPEC.spec.md @@ -1,5 +1,89 @@ # Ioc Register Analyzer +## Pipeline Architecture + +### Entry Points + +| Analyzer | Role | File | +| :------- | :--- | :--- | +| `RegisterAnalyzer` | Registration validation, dependency/lifetime checks, key diagnostics | `Analyzer/RegisterAnalyzer.cs` | +| `ContainerAnalyzer` | Container class validation, resolvability, module cycle detection | `Analyzer/ContainerAnalyzer.cs` | + +Both are `[DiagnosticAnalyzer(LanguageNames.CSharp)]` with lifecycle: `Initialize` β†’ `CompilationStart` β†’ actions β†’ `CompilationEnd`. + +### RegisterAnalyzer + +#### CompilationStart Setup + +1. Parse `SourceGenIocFeatures` via `ParseIocFeatures` +2. Queue SGIOC026 if `AsyncMethodInject` enabled without `MethodInject` +3. Build `IocAttributeSymbols` cache +4. Early exit if no registration attributes present +5. Build concurrent state: `RegisteredServices`, `ServiceTypeIndex`, `DuplicatedDefaults`, `SeenDefaultTargetTypes`, `DefaultSettingsMap` +6. Construct `AnalyzerContext` +7. Pre-collect assembly-level `IocRegisterFor` attributes + +#### Registered Actions + +| Action Type | Handler | Purpose | Diagnostics | +| :---------- | :------ | :------ | :---------- | +| Symbol (NamedType) | `CollectAndValidateNamedType` | Collect registrations, validate target type/instance lifetime, track duplicates | SGIOC001, 009, 011 | +| Symbol (NamedType) | `AnalyzeTypeLevelDefaultsAttribute` | Validate duplicate type-level defaults | SGIOC012 | +| Symbol (Property/Field/Method) | `AnalyzeInjectAttribute` | Validate `[IocInject]` usage and feature gating | SGIOC007, 022, 028 | +| Symbol (Method) | `AnalyzeDuplicatedKeyedServiceAttributes` | Duplicate keyed-service parameter attributes | SGIOC006 | +| SyntaxNode (Attribute) | `AnalyzeFactoryAndInstanceOnAttribute` | Factory/Instance member references, InjectMembers validation | SGIOC008, 010, 016, 017, 023, 024 | +| SyntaxNode (Attribute) | `ResolveCsharpKeyTypes` | Resolve nameof-based key type symbols | Enables SGIOC013/015 | +| Symbol (Method) | `AnalyzeIocGenericFactoryAttribute` | Duplicate placeholders in generic factory mapping | SGIOC017 | +| SemanticModel | `AnalyzeAssemblyLevelRegistrations` | Assembly-level register-for IDE squiggles | SGIOC001, 009, 011 | +| CompilationEnd | `AnalyzeAllDependencies` | Full-graph dependency/key/async-init analysis | SGIOC002-005, 012-015, 030 | + +#### CompilationEnd: AnalyzeAllDependencies + +1. Report buffered duplicated defaults (SGIOC012) +2. Per `ServiceInfo`: key type mismatch (SGIOC013/014), KVP key mismatch (SGIOC015), dependency DFS (SGIOC002-005) +3. Build async-init-only pairs β†’ sync dependency analysis (SGIOC030) + +#### Partial Files + +| File | Responsibility | +| :--- | :------------- | +| `RegisterAnalyzer.ServiceCollection.cs` | `AnalyzerContext`, `ServiceInfo`, service collection/indexing, defaults, assembly-level scan | +| `RegisterAnalyzer.DependencyAnalysis.cs` | Circular dependency DFS, lifetime conflicts, `TryUnwrapServiceType` (private wrapper unwrap) | +| `RegisterAnalyzer.DuplicatedRegistration.cs` | Duplicate registration/defaults detection (SGIOC011, 012) | +| `RegisterAnalyzer.AttributeUsage.cs` | Inject/factory/instance/generic factory/inject-members validation | +| `RegisterAnalyzer.UnresolvableMembers.cs` | ServiceKey type matching, KeyValuePair/Dictionary key analysis | + +### ContainerAnalyzer + +#### CompilationStart Setup + +1. Parse `SourceGenIocFeatures` +2. Build `IocAttributeSymbols` +3. Early exit if `IocContainerAttribute` unavailable +4. Build `ContainerAnalyzerContext`: `RegisteredServiceTypes`, `ServiceImplementationTypes`, `Registrations`, `ContainersWithNoFallback`, `AllContainers`, `ImportEdges` +5. Collect assembly-level register-for registrations + +#### Registered Actions + +| Action Type | Handler | Purpose | Diagnostics | +| :---------- | :------ | :------ | :---------- | +| Symbol (NamedType) | `AnalyzeContainerClass` | Container partial/static validation, import edges, UseSwitchStatement | SGIOC019, 020 | +| Symbol (NamedType) | `CollectRegisteredServices` | Collect registrations from IocRegister/IocRegisterFor | Feeds 018, 021, 027, 029 | +| CompilationEnd | `AnalyzeContainerDependencies` | Async accessor contracts, container dependency resolvability | SGIOC018, 021, 027, 029 | +| CompilationEnd | `AnalyzeCircularImports` | Module import cycle detection via DFS | SGIOC025 | + +#### CompilationEnd: AnalyzeContainerDependencies + +Pass 1 (all containers): `AnalyzeAsyncPartialAccessors` β†’ SGIOC027 (sync accessor on async-init), SGIOC029 (unsupported async return) +Pass 2 (IntegrateServiceProvider=false only): `AnalyzeContainerServiceDependencies` (SGIOC018), `AnalyzePartialAccessorDependencies` (SGIOC021) + +### Diagnostic Coverage + +| Analyzer | Diagnostics | +| :------- | :---------- | +| RegisterAnalyzer | SGIOC001, 002, 003, 004, 005, 006, 007, 008, 009, 010, 011, 012, 013, 014, 015, 016, 017, 022, 023, 024, 026, 028, 030 | +| ContainerAnalyzer | SGIOC018, 019, 020, 021, 025, 027, 029 | + ## Diagnostics Format: ID - Level - Category - Description diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerAsyncResolverWriters.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerAsyncResolverWriters.cs new file mode 100644 index 0000000..84e446d --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerAsyncResolverWriters.cs @@ -0,0 +1,37 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + // ────────────────────────────────────────────────────────────────────────────── + + /// + /// Returns the async routing resolver method name by appending "Async" to the sync method name. + /// + private static string GetAsyncResolverMethodName(string syncMethodName) + => syncMethodName + "Async"; + + /// + /// Returns the async creation method name (e.g. "CreateFooBarAsync" from "GetFooBar"). + /// + private static string GetAsyncCreateMethodName(string syncMethodName) + { + if(syncMethodName.Length > 3 && syncMethodName.StartsWith("Get", StringComparison.Ordinal)) + return "Create" + syncMethodName[3..] + "Async"; + return syncMethodName + "_CreateAsync"; + } + + /// + /// Returns the effective thread-safety strategy for a registration. + /// Async-init services auto-upgrade async-incompatible strategies to . + /// + private static ThreadSafeStrategy GetEffectiveThreadSafeStrategy( + ThreadSafeStrategy strategy, + bool isAsyncInit) + { + if(!isAsyncInit) + return strategy; + + return strategy is ThreadSafeStrategy.None ? ThreadSafeStrategy.None : ThreadSafeStrategy.SemaphoreSlim; + } + +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerEntry.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerEntry.cs new file mode 100644 index 0000000..5678d28 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerEntry.cs @@ -0,0 +1,983 @@ +using static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + internal abstract record class ContainerEntry + { + public virtual void WriteField(SourceWriter writer) + { + } + + public abstract void WriteResolver(SourceWriter writer); + + public virtual void WriteEagerInit(SourceWriter writer) + { + } + + public virtual void WriteDisposal(SourceWriter writer, bool isAsync) + { + } + + public virtual void WriteInit(SourceWriter writer) + { + } + + public virtual void WriteCollectionResolver(SourceWriter writer) + { + } + + public virtual void WriteLocalResolverEntries(SourceWriter writer) + { + } + } + + private abstract record class ServiceContainerEntry( + ServiceRegistrationModel Registration, + string ResolverMethodName, + ImmutableEquatableArray ConstructorParameters, + ImmutableEquatableArray InjectionMembers, + ImmutableEquatableArray Decorators) : ContainerEntry + { + protected void WriteInstanceCreationWithInjection(SourceWriter writer) + { + var typeDeclaration = Decorators.Length > 0 ? Registration.ServiceType.Name : "var"; + + if(Registration.Factory is not null) + { + var factoryCall = BuildFactoryCallExpression(); + writer.WriteLine($"{typeDeclaration} instance = ({Registration.ImplementationType.Name}){factoryCall};"); + return; + } + + WriteConstructorWithInjection(writer); + } + + protected void WriteConstructorWithInjection(SourceWriter writer) + { + var typeDeclaration = Decorators.Length > 0 ? Registration.ServiceType.Name : "var"; + + var properties = new List(); + var methods = new List(); + + foreach(var member in InjectionMembers) + { + switch(member.Member.MemberType) + { + case InjectionMemberType.Property: + case InjectionMemberType.Field: + properties.Add(member); + break; + + case InjectionMemberType.Method: + methods.Add(member); + break; + } + } + + WriteConstructorWithPropertyInitializers(writer, "instance", typeDeclaration, ConstructorParameters, properties); + + foreach(var method in methods) + { + var methodArgs = BuildMethodArguments(method, Registration.Key); + writer.WriteLine($"instance.{method.Member.Name}({methodArgs});"); + } + } + + protected void WriteDecoratorApplication(SourceWriter writer) + { + for(var i = Decorators.Length - 1; i >= 0; i--) + { + var decorator = Decorators[i]; + var decoratorName = decorator.Decorator.ImplementationType.Name; + var hasInjectionMembers = decorator.InjectionMembers.Length > 0; + + var args = new List(1 + decorator.Parameters.Length) + { + "instance" + }; + + foreach(var parameter in decorator.Parameters) + { + args.Add(parameter.Dependency.FormatExpression(parameter.IsOptional)); + } + + var argsString = string.Join(", ", args); + + if(hasInjectionMembers) + { + var decoratorVarName = $"decorator{Decorators.Length - 1 - i}"; + WriteDecoratorCreationWithInjection(writer, decoratorVarName, decoratorName, decorator.InjectionMembers, argsString, Registration.Key); + writer.WriteLine($"instance = {decoratorVarName};"); + } + else + { + writer.WriteLine($"instance = new {decoratorName}({argsString});"); + } + } + } + + protected void WriteAsyncInstanceCreationBody(SourceWriter writer) + { + var properties = new List(); + var syncMethods = new List(); + var asyncMethods = new List(); + + foreach(var member in InjectionMembers) + { + switch(member.Member.MemberType) + { + case InjectionMemberType.Property: + case InjectionMemberType.Field: + properties.Add(member); + break; + + case InjectionMemberType.Method: + syncMethods.Add(member); + break; + + case InjectionMemberType.AsyncMethod: + asyncMethods.Add(member); + break; + } + } + + var hasMethods = syncMethods.Count > 0 || asyncMethods.Count > 0; + var hasDecorators = Decorators.Length > 0; + var needsTwoVarPattern = hasDecorators && hasMethods; + + var injectionVar = needsTwoVarPattern ? "baseInstance" : "instance"; + var typeDeclaration = hasDecorators && !needsTwoVarPattern ? Registration.ServiceType.Name : "var"; + + if(Registration.Factory is not null) + { + var factoryCall = BuildFactoryCallExpression(); + writer.WriteLine($"{typeDeclaration} {injectionVar} = ({Registration.ImplementationType.Name}){factoryCall};"); + } + else + { + WriteConstructorWithPropertyInitializers(writer, injectionVar, typeDeclaration, ConstructorParameters, properties); + } + + foreach(var method in syncMethods) + { + var methodArgs = BuildMethodArguments(method, Registration.Key); + writer.WriteLine($"{injectionVar}.{method.Member.Name}({methodArgs});"); + } + + foreach(var method in asyncMethods) + { + var methodArgs = BuildMethodArguments(method, Registration.Key); + writer.WriteLine($"await {injectionVar}.{method.Member.Name}({methodArgs});"); + } + + if(hasDecorators) + { + writer.WriteLine(); + if(needsTwoVarPattern) + { + writer.WriteLine($"{Registration.ServiceType.Name} instance = {injectionVar};"); + } + + WriteDecoratorApplication(writer); + } + + writer.WriteLine("return instance;"); + } + + private string BuildFactoryCallExpression() + { + var factory = Registration.Factory!; + + var args = new List(); + if(factory.HasServiceProvider) + { + args.Add("this"); + } + + if(factory.HasKey && Registration.Key is not null) + { + args.Add(Registration.Key); + } + + foreach(var parameter in ConstructorParameters) + { + args.Add(parameter.Dependency.FormatExpression(parameter.IsOptional)); + } + + var genericTypeArgs = BuildGenericFactoryTypeArgs(factory, Registration.ServiceType); + var factoryCallPath = genericTypeArgs is not null ? $"{factory.Path}<{genericTypeArgs}>" : factory.Path; + return $"{factoryCallPath}({string.Join(", ", args)})"; + } + + private void WriteConstructorWithPropertyInitializers( + SourceWriter writer, + string variableName, + string typeDeclaration, + ImmutableEquatableArray constructorParameters, + List propertyMembers) + { + var args = constructorParameters.Length == 0 + ? "" + : string.Join(", ", constructorParameters.Select(static p => p.Dependency.FormatExpression(p.IsOptional))); + + if(propertyMembers.Count == 0) + { + writer.WriteLine($"{typeDeclaration} {variableName} = new {Registration.ImplementationType.Name}({args});"); + return; + } + + writer.WriteLine($"{typeDeclaration} {variableName} = new {Registration.ImplementationType.Name}({args})"); + writer.WriteLine("{"); + writer.Indentation++; + + foreach(var member in propertyMembers) + { + if(member.Dependency is null) + throw new InvalidOperationException($"Missing resolved dependency for injection member '{member.Member.Name}'."); + + writer.WriteLine($"{member.Member.Name} = {member.Dependency.FormatExpression(member.Member.IsNullable)},"); + } + + writer.Indentation--; + writer.WriteLine("};"); + } + + private static void WriteDecoratorCreationWithInjection( + SourceWriter writer, + string variableName, + string decoratorTypeName, + ImmutableEquatableArray injectionMembers, + string argsString, + string? registrationKey) + { + if(injectionMembers.Length == 0) + { + writer.WriteLine($"var {variableName} = new {decoratorTypeName}({argsString});"); + return; + } + + var propertyAssignments = new List(); + var methodInvocations = new List(); + + foreach(var member in injectionMembers) + { + switch(member.Member.MemberType) + { + case InjectionMemberType.Property: + case InjectionMemberType.Field: + if(member.Dependency is null) + throw new InvalidOperationException($"Missing resolved dependency for decorator member '{member.Member.Name}'."); + + propertyAssignments.Add($"{member.Member.Name} = {member.Dependency.FormatExpression(member.Member.IsNullable)},"); + break; + + case InjectionMemberType.Method: + var args = BuildMethodArguments(member, registrationKey); + methodInvocations.Add($"{variableName}.{member.Member.Name}({args});"); + break; + } + } + + if(propertyAssignments.Count == 0) + { + writer.WriteLine($"var {variableName} = new {decoratorTypeName}({argsString});"); + } + else + { + writer.WriteLine($"var {variableName} = new {decoratorTypeName}({argsString})"); + writer.WriteLine("{"); + writer.Indentation++; + foreach(var assignment in propertyAssignments) + { + writer.WriteLine(assignment); + } + + writer.Indentation--; + writer.WriteLine("};"); + } + + foreach(var invocation in methodInvocations) + { + writer.WriteLine(invocation); + } + } + + private static string BuildMethodArguments(ResolvedInjectionMember member, string? registrationKey) + { + var parameters = member.Member.Parameters; + if(parameters is null or { Length: 0 }) + return ""; + + var args = new string[parameters.Length]; + var parameterDependencies = member.ParameterDependencies; + + for(var i = 0; i < parameters.Length; i++) + { + var parameter = parameters[i]; + + if(parameter.HasServiceKeyAttribute) + { + args[i] = registrationKey ?? "null"; + continue; + } + + if(parameter.Type.Name is IServiceProviderTypeName or IServiceProviderGlobalTypeName) + { + args[i] = "this"; + continue; + } + + if(parameterDependencies.Length > i) + { + args[i] = parameterDependencies[i].FormatExpression(parameter.IsOptional); + continue; + } + + if(parameters.Length == 1 && member.Dependency is not null) + { + args[i] = member.Dependency.FormatExpression(parameter.IsOptional); + continue; + } + + throw new InvalidOperationException($"Missing resolved dependency for injection method parameter '{parameter.Name}' in member '{member.Member.Name}'."); + } + + return string.Join(", ", args); + } + } + + private sealed record class InstanceContainerEntry( + ServiceRegistrationModel Registration, + string ResolverMethodName, + ImmutableEquatableArray ConstructorParameters, + ImmutableEquatableArray InjectionMembers, + ImmutableEquatableArray Decorators) : ServiceContainerEntry(Registration, ResolverMethodName, ConstructorParameters, InjectionMembers, Decorators) + { + public override void WriteField(SourceWriter writer) + { + } + + public override void WriteResolver(SourceWriter writer) + { + } + } + + private sealed record class EagerContainerEntry( + ServiceRegistrationModel Registration, + string ResolverMethodName, + string FieldName, + ImmutableEquatableArray ConstructorParameters, + ImmutableEquatableArray InjectionMembers, + ImmutableEquatableArray Decorators) : ServiceContainerEntry(Registration, ResolverMethodName, ConstructorParameters, InjectionMembers, Decorators) + { + public override void WriteField(SourceWriter writer) + { + var returnType = Decorators.Length > 0 ? Registration.ServiceType.Name : Registration.ImplementationType.Name; + writer.WriteLine($"private {returnType} {FieldName} = null!;"); + } + + public override void WriteResolver(SourceWriter writer) + { + var returnType = Decorators.Length > 0 ? Registration.ServiceType.Name : Registration.ImplementationType.Name; + + writer.WriteLine($"private {returnType} {ResolverMethodName}()"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteEarlyReturnIfNotNull(FieldName); + writer.WriteLine(); + + WriteInstanceCreationWithInjection(writer); + + if(Decorators.Length > 0) + { + writer.WriteLine(); + WriteDecoratorApplication(writer); + } + + writer.WriteLine(); + writer.WriteFieldAssignAndReturn(FieldName, "instance"); + + writer.Indentation--; + writer.WriteLine("}"); + } + + public override void WriteEagerInit(SourceWriter writer) + { + writer.WriteLine($"{ResolverMethodName}();"); + } + + public override void WriteDisposal(SourceWriter writer, bool isAsync) + { + var disposeMethod = isAsync ? "await DisposeServiceAsync" : "DisposeService"; + writer.WriteLine($"{disposeMethod}({FieldName});"); + } + } + + private sealed record class LazyThreadSafeContainerEntry( + ServiceRegistrationModel Registration, + string ResolverMethodName, + string FieldName, + ThreadSafeStrategy ThreadSafeStrategy, + ImmutableEquatableArray ConstructorParameters, + ImmutableEquatableArray InjectionMembers, + ImmutableEquatableArray Decorators) : ServiceContainerEntry(Registration, ResolverMethodName, ConstructorParameters, InjectionMembers, Decorators) + { + public override void WriteField(SourceWriter writer) + { + var returnType = Decorators.Length > 0 ? Registration.ServiceType.Name : Registration.ImplementationType.Name; + writer.WriteLine($"private {returnType}? {FieldName};"); + + var syncFieldDeclaration = ThreadSafeStrategy switch + { + ThreadSafeStrategy.Lock => $"private readonly Lock {FieldName}Lock = new();", + ThreadSafeStrategy.SemaphoreSlim => $"private readonly SemaphoreSlim {FieldName}Semaphore = new(1, 1);", + ThreadSafeStrategy.SpinLock => $"private SpinLock {FieldName}SpinLock = new(false);", + _ => null + }; + + if(syncFieldDeclaration is not null) + { + writer.WriteLine(syncFieldDeclaration); + } + } + + public override void WriteResolver(SourceWriter writer) + { + var returnType = Decorators.Length > 0 ? Registration.ServiceType.Name : Registration.ImplementationType.Name; + + writer.WriteLine($"private {returnType} {ResolverMethodName}()"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteEarlyReturnIfNotNull(FieldName); + writer.WriteLine(); + + Action writeResolverBody = ThreadSafeStrategy switch + { + ThreadSafeStrategy.None => WriteResolverBodyNone, + ThreadSafeStrategy.Lock => WriteResolverBodyLock, + ThreadSafeStrategy.SemaphoreSlim => WriteResolverBodySemaphoreSlim, + ThreadSafeStrategy.SpinLock => WriteResolverBodySpinLock, + ThreadSafeStrategy.CompareExchange => WriteResolverBodyCompareExchange, + _ => WriteResolverBodyNone + }; + + writeResolverBody(writer); + + writer.Indentation--; + writer.WriteLine("}"); + } + + public override void WriteDisposal(SourceWriter writer, bool isAsync) + { + var disposeMethod = isAsync ? "await DisposeServiceAsync" : "DisposeService"; + writer.WriteLine($"{disposeMethod}({FieldName});"); + + if(ThreadSafeStrategy == ThreadSafeStrategy.SemaphoreSlim) + { + writer.WriteLine($"{FieldName}Semaphore.Dispose();"); + } + } + + private void WriteResolverBodyNone(SourceWriter writer) + { + WriteInstanceCreationAndAssignment(writer); + } + + private void WriteResolverBodyLock(SourceWriter writer) + { + writer.WriteLine($"lock({FieldName}Lock)"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteEarlyReturnIfNotNull(FieldName); + writer.WriteLine(); + + WriteInstanceCreationAndAssignment(writer); + + writer.Indentation--; + writer.WriteLine("}"); + } + + private void WriteResolverBodySemaphoreSlim(SourceWriter writer) + { + writer.WriteLine($"{FieldName}Semaphore.Wait();"); + writer.WriteLine("try"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteEarlyReturnIfNotNull(FieldName); + writer.WriteLine(); + + WriteInstanceCreationAndAssignment(writer); + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("finally"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine($"{FieldName}Semaphore.Release();"); + writer.Indentation--; + writer.WriteLine("}"); + } + + private void WriteResolverBodySpinLock(SourceWriter writer) + { + writer.WriteLine("bool lockTaken = false;"); + writer.WriteLine("try"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine($"{FieldName}SpinLock.Enter(ref lockTaken);"); + writer.WriteEarlyReturnIfNotNull(FieldName); + writer.WriteLine(); + + WriteInstanceCreationAndAssignment(writer); + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("finally"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine($"if(lockTaken) {FieldName}SpinLock.Exit();"); + writer.Indentation--; + writer.WriteLine("}"); + } + + private void WriteResolverBodyCompareExchange(SourceWriter writer) + { + WriteInstanceCreationWithInjection(writer); + + if(Decorators.Length > 0) + { + writer.WriteLine(); + WriteDecoratorApplication(writer); + } + + writer.WriteLine(); + writer.WriteLine($"var existing = Interlocked.CompareExchange(ref {FieldName}, instance, null);"); + writer.WriteLine("if(existing is not null)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("DisposeService(instance);"); + writer.WriteLine("return existing;"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("return instance;"); + } + + private void WriteInstanceCreationAndAssignment(SourceWriter writer) + { + WriteInstanceCreationWithInjection(writer); + + if(Decorators.Length > 0) + { + writer.WriteLine(); + WriteDecoratorApplication(writer); + } + + writer.WriteLine(); + writer.WriteFieldAssignAndReturn(FieldName, "instance"); + } + } + + private sealed record class TransientContainerEntry( + ServiceRegistrationModel Registration, + string ResolverMethodName, + ImmutableEquatableArray ConstructorParameters, + ImmutableEquatableArray InjectionMembers, + ImmutableEquatableArray Decorators) : ServiceContainerEntry(Registration, ResolverMethodName, ConstructorParameters, InjectionMembers, Decorators) + { + public override void WriteResolver(SourceWriter writer) + { + var returnType = Decorators.Length > 0 ? Registration.ServiceType.Name : Registration.ImplementationType.Name; + + writer.WriteLine($"private {returnType} {ResolverMethodName}()"); + writer.WriteLine("{"); + writer.Indentation++; + + if(InjectionMembers.Length == 0 && Decorators.Length == 0) + { + if(Registration.Factory is not null) + { + var factory = Registration.Factory; + var args = new List(); + + if(factory.HasServiceProvider) + { + args.Add("this"); + } + + if(factory.HasKey && Registration.Key is not null) + { + args.Add(Registration.Key); + } + + foreach(var parameter in ConstructorParameters) + { + args.Add(parameter.Dependency.FormatExpression(parameter.IsOptional)); + } + + var genericTypeArgs = BuildGenericFactoryTypeArgs(factory, Registration.ServiceType); + var factoryCallPath = genericTypeArgs is not null ? $"{factory.Path}<{genericTypeArgs}>" : factory.Path; + writer.WriteLine($"return ({Registration.ImplementationType.Name}){factoryCallPath}({string.Join(", ", args)});"); + } + else + { + var ctorArgs = ConstructorParameters.Length == 0 + ? "" + : string.Join(", ", ConstructorParameters.Select(static p => p.Dependency.FormatExpression(p.IsOptional))); + writer.WriteLine($"return new {Registration.ImplementationType.Name}({ctorArgs});"); + } + + writer.Indentation--; + writer.WriteLine("}"); + return; + } + + WriteInstanceCreationWithInjection(writer); + + if(Decorators.Length > 0) + { + writer.WriteLine(); + WriteDecoratorApplication(writer); + } + + writer.WriteLine("return instance;"); + + writer.Indentation--; + writer.WriteLine("}"); + } + } + + private sealed record class AsyncContainerEntry( + ServiceRegistrationModel Registration, + string ResolverMethodName, + string FieldName, + bool IsEager, + ThreadSafeStrategy ThreadSafeStrategy, + ImmutableEquatableArray ConstructorParameters, + ImmutableEquatableArray InjectionMembers, + ImmutableEquatableArray Decorators) : ServiceContainerEntry(Registration, ResolverMethodName, ConstructorParameters, InjectionMembers, Decorators) + { + public override void WriteField(SourceWriter writer) + { + var returnType = Decorators.Length > 0 ? Registration.ServiceType.Name : Registration.ImplementationType.Name; + var taskReturnType = $"global::System.Threading.Tasks.Task<{returnType}>"; + + writer.WriteLine($"private {taskReturnType}? {FieldName};"); + + if(ThreadSafeStrategy == ThreadSafeStrategy.SemaphoreSlim) + { + writer.WriteLine($"private readonly global::System.Threading.SemaphoreSlim {FieldName}Semaphore = new(1, 1);"); + } + } + + public override void WriteResolver(SourceWriter writer) + { + var returnType = Decorators.Length > 0 ? Registration.ServiceType.Name : Registration.ImplementationType.Name; + var asyncMethodName = GetAsyncResolverMethodName(ResolverMethodName); + var createMethodName = GetAsyncCreateMethodName(ResolverMethodName); + var taskReturnType = $"global::System.Threading.Tasks.Task<{returnType}>"; + + writer.WriteLine($"private async {taskReturnType} {asyncMethodName}()"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine($"if({FieldName} is not null)"); + writer.Indentation++; + writer.WriteLine($"return await {FieldName};"); + writer.Indentation--; + writer.WriteLine(); + + if(ThreadSafeStrategy == ThreadSafeStrategy.SemaphoreSlim) + { + WriteAsyncResolverBodySemaphoreSlim(writer, createMethodName); + } + else + { + WriteAsyncResolverBodyNone(writer, createMethodName); + } + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine($"private async {taskReturnType} {createMethodName}()"); + writer.WriteLine("{"); + writer.Indentation++; + + WriteAsyncInstanceCreationBody(writer); + + writer.Indentation--; + writer.WriteLine("}"); + } + + public override void WriteEagerInit(SourceWriter writer) + { + if(IsEager) + { + writer.WriteLine($"_ = {GetAsyncResolverMethodName(ResolverMethodName)}();"); + } + } + + public override void WriteDisposal(SourceWriter writer, bool isAsync) + { + var disposeMethod = isAsync ? "await DisposeServiceAsync" : "DisposeService"; + writer.WriteLine($"{disposeMethod}({FieldName});"); + + if(ThreadSafeStrategy == ThreadSafeStrategy.SemaphoreSlim) + { + writer.WriteLine($"{FieldName}Semaphore.Dispose();"); + } + } + + private void WriteAsyncResolverBodyNone(SourceWriter writer, string createMethodName) + { + writer.WriteLine($"{FieldName} = {createMethodName}();"); + writer.WriteLine($"return await {FieldName};"); + } + + private void WriteAsyncResolverBodySemaphoreSlim(SourceWriter writer, string createMethodName) + { + writer.WriteLine($"await {FieldName}Semaphore.WaitAsync();"); + writer.WriteLine("try"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine($"if({FieldName} is null)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine($"{FieldName} = {createMethodName}();"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("finally"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine($"{FieldName}Semaphore.Release();"); + writer.Indentation--; + writer.WriteLine("}"); + + writer.WriteLine($"return await {FieldName};"); + } + } + + private sealed record class AsyncTransientContainerEntry( + ServiceRegistrationModel Registration, + string ResolverMethodName, + ImmutableEquatableArray ConstructorParameters, + ImmutableEquatableArray InjectionMembers, + ImmutableEquatableArray Decorators) : ServiceContainerEntry(Registration, ResolverMethodName, ConstructorParameters, InjectionMembers, Decorators) + { + public override void WriteResolver(SourceWriter writer) + { + var returnType = Decorators.Length > 0 ? Registration.ServiceType.Name : Registration.ImplementationType.Name; + var createMethodName = GetAsyncCreateMethodName(ResolverMethodName); + var taskReturnType = $"global::System.Threading.Tasks.Task<{returnType}>"; + + writer.WriteLine($"private async {taskReturnType} {createMethodName}()"); + writer.WriteLine("{"); + writer.Indentation++; + + WriteAsyncInstanceCreationBody(writer); + + writer.Indentation--; + writer.WriteLine("}"); + } + } + + private sealed record class LazyWrapperContainerEntry( + string InnerServiceTypeName, + string InnerImplTypeName, + string FieldName, + string InnerResolverMethodName, + string? Key, + bool EmitCollectionResolver, + ImmutableEquatableArray CollectionFieldNames) : ContainerEntry + { + public override void WriteField(SourceWriter writer) + { + writer.WriteLine($"private readonly global::System.Lazy<{InnerServiceTypeName}> {FieldName};"); + } + + public override void WriteResolver(SourceWriter writer) + { + } + + public override void WriteInit(SourceWriter writer) + { + writer.WriteLine($"{FieldName} = new global::System.Lazy<{InnerServiceTypeName}>(() => {InnerResolverMethodName}(), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication);"); + } + + public override void WriteCollectionResolver(SourceWriter writer) + { + if(!EmitCollectionResolver) + { + return; + } + + var wrapperTypeName = $"global::System.Lazy<{InnerServiceTypeName}>"; + var arrayMethodName = GetLazyArrayResolverMethodName(InnerServiceTypeName); + + writer.WriteLine($"private {wrapperTypeName}[] {arrayMethodName}() =>"); + writer.Indentation++; + writer.WriteLine("["); + writer.Indentation++; + + foreach(var fieldName in CollectionFieldNames) + { + writer.WriteLine($"{fieldName},"); + } + + writer.Indentation--; + writer.WriteLine("];"); + writer.Indentation--; + } + + public override void WriteLocalResolverEntries(SourceWriter writer) + { + if(!EmitCollectionResolver) + { + return; + } + + var wrapperTypeName = $"global::System.Lazy<{InnerServiceTypeName}>"; + var arrayMethodName = GetLazyArrayResolverMethodName(InnerServiceTypeName); + + writer.WriteLine($"new(new ServiceIdentifier(typeof({wrapperTypeName}), {KeyedServiceAnyKey}), static c => c.{FieldName}),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IEnumerable<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyCollection<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.ICollection<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyList<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IList<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof({wrapperTypeName}[]), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + } + } + + private sealed record class FuncWrapperContainerEntry( + string InnerServiceTypeName, + string InnerImplTypeName, + string FieldName, + string InnerResolverMethodName, + string? Key, + bool EmitCollectionResolver, + ImmutableEquatableArray CollectionFieldNames) : ContainerEntry + { + public override void WriteField(SourceWriter writer) + { + writer.WriteLine($"private readonly global::System.Func<{InnerServiceTypeName}> {FieldName};"); + } + + public override void WriteResolver(SourceWriter writer) + { + } + + public override void WriteInit(SourceWriter writer) + { + writer.WriteLine($"{FieldName} = new global::System.Func<{InnerServiceTypeName}>(() => {InnerResolverMethodName}());"); + } + + public override void WriteCollectionResolver(SourceWriter writer) + { + if(!EmitCollectionResolver) + { + return; + } + + var wrapperTypeName = $"global::System.Func<{InnerServiceTypeName}>"; + var arrayMethodName = GetFuncArrayResolverMethodName(InnerServiceTypeName); + + writer.WriteLine($"private {wrapperTypeName}[] {arrayMethodName}() =>"); + writer.Indentation++; + writer.WriteLine("["); + writer.Indentation++; + + foreach(var fieldName in CollectionFieldNames) + { + writer.WriteLine($"{fieldName},"); + } + + writer.Indentation--; + writer.WriteLine("];"); + writer.Indentation--; + } + + public override void WriteLocalResolverEntries(SourceWriter writer) + { + if(!EmitCollectionResolver) + { + return; + } + + var wrapperTypeName = $"global::System.Func<{InnerServiceTypeName}>"; + var arrayMethodName = GetFuncArrayResolverMethodName(InnerServiceTypeName); + + writer.WriteLine($"new(new ServiceIdentifier(typeof({wrapperTypeName}), {KeyedServiceAnyKey}), static c => c.{FieldName}),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IEnumerable<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyCollection<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.ICollection<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyList<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IList<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof({wrapperTypeName}[]), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + } + } + + private sealed record class KvpWrapperContainerEntry( + string KeyTypeName, + string ValueTypeName, + string KeyExpr, + string ResolverMethodName, + string KvpResolverMethodName) : ContainerEntry + { + public override void WriteResolver(SourceWriter writer) + { + var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{KeyTypeName}, {ValueTypeName}>"; + writer.WriteLine($"private {kvpTypeName} {KvpResolverMethodName}() => new {kvpTypeName}({KeyExpr}, {ResolverMethodName}());"); + } + + public override void WriteLocalResolverEntries(SourceWriter writer) + { + var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{KeyTypeName}, {ValueTypeName}>"; + var arrayMethodName = GetKvpArrayResolverMethodName(KeyTypeName, ValueTypeName); + var dictionaryMethodName = GetKvpDictionaryResolverMethodName(KeyTypeName, ValueTypeName); + + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IEnumerable<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyCollection<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.ICollection<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyList<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IList<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof({kvpTypeName}[]), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyDictionary<{KeyTypeName}, {ValueTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IDictionary<{KeyTypeName}, {ValueTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.Dictionary<{KeyTypeName}, {ValueTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); + } + } + + private sealed record class CollectionContainerEntry( + string ElementServiceTypeName, + string ArrayMethodName, + ImmutableEquatableArray ElementResolvers) : ContainerEntry + { + public override void WriteResolver(SourceWriter writer) + { + writer.WriteLine($"private {ElementServiceTypeName}[] {ArrayMethodName}() =>"); + writer.Indentation++; + writer.WriteLine("["); + writer.Indentation++; + + foreach(var elementResolver in ElementResolvers) + { + writer.WriteLine($"{elementResolver.FormatExpression(false)},"); + } + + writer.Indentation--; + writer.WriteLine("];"); + writer.Indentation--; + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerInjectionWriters.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerInjectionWriters.cs new file mode 100644 index 0000000..bec0dea --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerInjectionWriters.cs @@ -0,0 +1,30 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Categorizes injection members into properties/fields and methods. + /// + private static (List? Properties, List? Methods) CategorizeInjectionMembers( + ImmutableEquatableArray injectionMembers) + { + List? properties = null; + List? methods = null; + + foreach(var member in injectionMembers) + { + if(member.MemberType is InjectionMemberType.Property or InjectionMemberType.Field) + { + properties ??= []; + properties.Add(member); + } + else if(member.MemberType == InjectionMemberType.Method) + { + methods ??= []; + methods.Add(member); + } + } + + return (properties, methods); + } +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerInterfaceWriters.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerInterfaceWriters.cs new file mode 100644 index 0000000..7ebdba0 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerInterfaceWriters.cs @@ -0,0 +1,1245 @@ +using static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + + /// + /// Writes IServiceProvider implementation. + /// + private static void WriteIServiceProviderImplementation( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups, + bool effectiveUseSwitchStatement) + { + writer.WriteLine("#region IServiceProvider"); + writer.WriteLine(); + + writer.WriteLine("public object? GetService(Type serviceType)"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine(); + + if(effectiveUseSwitchStatement) + { + // Built-in services (switch mode only - in dictionary mode these are in _localResolvers) + writer.WriteLine("if(serviceType == typeof(IServiceProvider)) return this;"); + writer.WriteLine("if(serviceType == typeof(IServiceScopeFactory)) return this;"); + writer.WriteLine($"if(serviceType == typeof({container.ClassName})) return this;"); + writer.WriteLine(); + + // Cascading if statements - registrations already filtered (no open generics) + foreach(var mapping in EnumerateInterfaceServiceResolvers(groups)) + { + if(mapping.Key is not null) + continue; + + writer.WriteLine($"if(serviceType == typeof({mapping.ServiceType})) return {mapping.ResolverExpression};"); + } + + writer.WriteLine(); + } + else + { + writer.WriteLine($"if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, {KeyedServiceAnyKey}), out var resolver))"); + writer.Indentation++; + writer.WriteLine("return resolver(this);"); + writer.Indentation--; + writer.WriteLine(); + } + + // Fallback + if(container.IntegrateServiceProvider) + { + writer.WriteLine("return _fallbackProvider?.GetService(serviceType);"); + } + else + { + writer.WriteLine("return null;"); + } + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes IKeyedServiceProvider implementation. + /// + private static void WriteIKeyedServiceProviderImplementation( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups, + bool effectiveUseSwitchStatement) + { + writer.WriteLine("#region IKeyedServiceProvider"); + writer.WriteLine(); + + writer.WriteLine("public object? GetKeyedService(Type serviceType, object? serviceKey)"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine(); + + writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); + writer.WriteLine(); + + if(effectiveUseSwitchStatement) + { + // Use pre-computed hasKeyedServices flag + if(groups.HasKeyedServices) + { + // Use tuple pattern matching switch expression + writer.WriteLine("return (serviceType, key) switch"); + writer.WriteLine("{"); + writer.Indentation++; + + foreach(var mapping in EnumerateKeyedServiceResolvers(groups)) + { + if(mapping.Key is null) + continue; + + writer.WriteLine($"(Type t, object k) when t == typeof({mapping.ServiceType}) && Equals(k, {mapping.Key}) => {mapping.ResolverExpression},"); + } + + // Fallback in switch default case + if(container.IntegrateServiceProvider) + { + writer.WriteLine("_ => _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null"); + } + else + { + writer.WriteLine("_ => null"); + } + + writer.Indentation--; + writer.WriteLine("};"); + } + else + { + // No keyed services, just return fallback + if(container.IntegrateServiceProvider) + { + writer.WriteLine("return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null;"); + } + else + { + writer.WriteLine("return null;"); + } + } + } + else + { + writer.WriteLine("if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver))"); + writer.Indentation++; + writer.WriteLine("return resolver(this);"); + writer.Indentation--; + writer.WriteLine(); + + // Fallback + if(container.IntegrateServiceProvider) + { + writer.WriteLine("return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null;"); + } + else + { + writer.WriteLine("return null;"); + } + } + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // GetRequiredKeyedService + writer.WriteLine("public object GetRequiredKeyedService(Type serviceType, object? serviceKey)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine("return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($\"No service for type '{serviceType}' with key '{serviceKey}' has been registered.\");"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes ISupportRequiredService implementation. + /// + private static void WriteISupportRequiredServiceImplementation(SourceWriter writer, ContainerModel container) + { + writer.WriteLine("#region ISupportRequiredService"); + writer.WriteLine(); + + writer.WriteLine("public object GetRequiredService(Type serviceType)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine("return GetService(serviceType) ?? throw new InvalidOperationException($\"No service for type '{serviceType}' has been registered.\");"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes generic service resolution extension methods matching + /// ServiceProviderServiceExtensions and ServiceProviderKeyedServiceExtensions signatures. + /// In dictionary mode, methods directly query _serviceResolvers for optimal performance. + /// In switch mode, methods delegate to non-generic counterparts. + /// + private static void WriteServiceProviderExtensions( + SourceWriter writer, + ContainerModel container, + bool effectiveUseSwitchStatement) + { + writer.WriteLine("#region ServiceProvider Extensions"); + writer.WriteLine(); + + if(effectiveUseSwitchStatement) + { + WriteServiceProviderExtensionsSwitchMode(writer); + } + else + { + WriteServiceProviderExtensionsDictionaryMode(writer); + } + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes generic extension methods for switch/if-cascade mode. + /// These delegate to the non-generic methods. + /// + private static void WriteServiceProviderExtensionsSwitchMode(SourceWriter writer) + { + // GetService + writer.WriteLine("public T? GetService() where T : class"); + writer.Indentation++; + writer.WriteLine("=> GetService(typeof(T)) as T;"); + writer.Indentation--; + writer.WriteLine(); + + // GetRequiredService + writer.WriteLine("public T GetRequiredService() where T : notnull"); + writer.Indentation++; + writer.WriteLine("=> (T)GetRequiredService(typeof(T));"); + writer.Indentation--; + writer.WriteLine(); + + // GetServices + writer.WriteLine("public System.Collections.Generic.IEnumerable GetServices()"); + writer.Indentation++; + writer.WriteLine("=> (System.Collections.Generic.IEnumerable?)GetService(typeof(System.Collections.Generic.IEnumerable)) ?? [];"); + writer.Indentation--; + writer.WriteLine(); + + // GetKeyedService + writer.WriteLine("public T? GetKeyedService(object? serviceKey) where T : class"); + writer.Indentation++; + writer.WriteLine("=> GetKeyedService(typeof(T), serviceKey) as T;"); + writer.Indentation--; + writer.WriteLine(); + + // GetRequiredKeyedService + writer.WriteLine("public T GetRequiredKeyedService(object? serviceKey) where T : notnull"); + writer.Indentation++; + writer.WriteLine("=> (T)GetRequiredKeyedService(typeof(T), serviceKey);"); + writer.Indentation--; + writer.WriteLine(); + + // GetKeyedServices + writer.WriteLine("public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey)"); + writer.Indentation++; + writer.WriteLine("=> (System.Collections.Generic.IEnumerable?)GetKeyedService(typeof(System.Collections.Generic.IEnumerable), serviceKey) ?? [];"); + writer.Indentation--; + writer.WriteLine(); + } + + /// + /// Writes generic extension methods for dictionary mode. + /// These directly query _serviceResolvers FrozenDictionary for optimal performance. + /// + private static void WriteServiceProviderExtensionsDictionaryMode(SourceWriter writer) + { + // GetService + writer.WriteLine("public T? GetService() where T : class"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine($"return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), {KeyedServiceAnyKey}), out var resolver)"); + writer.Indentation++; + writer.WriteLine("? resolver(this) as T"); + writer.WriteLine(": null;"); + writer.Indentation--; + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // GetRequiredService + writer.WriteLine("public T GetRequiredService() where T : notnull"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine($"return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), {KeyedServiceAnyKey}), out var resolver)"); + writer.Indentation++; + writer.WriteLine("? (T)resolver(this)"); + writer.WriteLine(": throw new InvalidOperationException($\"No service for type '{typeof(T)}' has been registered.\");"); + writer.Indentation--; + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // GetServices + writer.WriteLine("public System.Collections.Generic.IEnumerable GetServices()"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine($"return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), {KeyedServiceAnyKey}), out var resolver)"); + writer.Indentation++; + writer.WriteLine("? (System.Collections.Generic.IEnumerable)resolver(this)"); + writer.WriteLine(": [];"); + writer.Indentation--; + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // GetKeyedService + writer.WriteLine("public T? GetKeyedService(object? serviceKey) where T : class"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); + writer.WriteLine("return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver)"); + writer.Indentation++; + writer.WriteLine("? resolver(this) as T"); + writer.WriteLine(": null;"); + writer.Indentation--; + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // GetRequiredKeyedService + writer.WriteLine("public T GetRequiredKeyedService(object? serviceKey) where T : notnull"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); + writer.WriteLine("return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver)"); + writer.Indentation++; + writer.WriteLine("? (T)resolver(this)"); + writer.WriteLine(": throw new InvalidOperationException($\"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered.\");"); + writer.Indentation--; + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // GetKeyedServices + writer.WriteLine("public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); + writer.WriteLine("return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver)"); + writer.Indentation++; + writer.WriteLine("? (System.Collections.Generic.IEnumerable)resolver(this)"); + writer.WriteLine(": [];"); + writer.Indentation--; + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + } + + /// + /// Writes IServiceProviderIsService implementation. + /// + private static void WriteIServiceProviderIsServiceImplementation( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups, + bool effectiveUseSwitchStatement) + { + writer.WriteLine("#region IServiceProviderIsService"); + writer.WriteLine(); + + // IsService + writer.WriteLine("public bool IsService(Type serviceType)"); + writer.WriteLine("{"); + writer.Indentation++; + + if(effectiveUseSwitchStatement) + { + // Built-in services (switch mode only - in dictionary mode these are in _localResolvers) + writer.WriteLine("if(serviceType == typeof(IServiceProvider)) return true;"); + writer.WriteLine("if(serviceType == typeof(IServiceScopeFactory)) return true;"); + writer.WriteLine($"if(serviceType == typeof({container.ClassName})) return true;"); + writer.WriteLine(); + + foreach(var serviceType in groups.AllServiceTypes) + { + writer.WriteLine($"if(serviceType == typeof({serviceType})) return true;"); + } + writer.WriteLine(); + } + else + { + writer.WriteLine($"if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, {KeyedServiceAnyKey}))) return true;"); + writer.WriteLine(); + } + + if(container.IntegrateServiceProvider) + { + writer.WriteLine("return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType);"); + } + else + { + writer.WriteLine("return false;"); + } + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // IsKeyedService + writer.WriteLine("public bool IsKeyedService(Type serviceType, object? serviceKey)"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); + + if(!effectiveUseSwitchStatement) + { + writer.WriteLine(); + writer.WriteLine("if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true;"); + } + + writer.WriteLine(); + + if(container.IntegrateServiceProvider) + { + writer.WriteLine("return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey);"); + } + else + { + writer.WriteLine("return false;"); + } + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes IServiceScopeFactory implementation. + /// + private static void WriteIServiceScopeFactoryImplementation(SourceWriter writer, ContainerModel container) + { + writer.WriteLine("#region IServiceScopeFactory"); + writer.WriteLine(); + + writer.WriteLine("public IServiceScope CreateScope()"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ThrowIfDisposed();"); + writer.WriteLine($"return new {container.ClassName}(this);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + writer.WriteLine("public AsyncServiceScope CreateAsyncScope() => new(CreateScope());"); + writer.WriteLine(); + writer.WriteLine("IServiceProvider IServiceScope.ServiceProvider => this;"); + writer.WriteLine(); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes IIocContainer implementation. + /// + private static void WriteIIocContainerImplementation( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups, + bool effectiveUseSwitchStatement) + { + writer.WriteLine("#region IIocContainer"); + writer.WriteLine(); + + if(effectiveUseSwitchStatement) + { + writer.WriteLine($"public static IReadOnlyCollection>> Resolvers => _localResolvers;"); + } + else + { + writer.WriteLine($"public static IReadOnlyCollection>> Resolvers => _serviceResolvers;"); + } + writer.WriteLine(); + + // Write _localResolvers as static field + writer.WriteLine($"private static readonly KeyValuePair>[] _localResolvers ="); + writer.WriteLine("["); + writer.Indentation++; + + // Built-in services: IServiceProvider, IServiceScopeFactory, and the container itself + writer.WriteLine($"new(new ServiceIdentifier(typeof(IServiceProvider), {KeyedServiceAnyKey}), static c => c),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(IServiceScopeFactory), {KeyedServiceAnyKey}), static c => c),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof({container.ContainerTypeName}), {KeyedServiceAnyKey}), static c => c),"); + + var entryByResolverMethodName = BuildEntryByResolverMethodName(groups); + + // Use service type resolver mappings (already filtered for non-open-generics) + foreach(var mapping in EnumerateLocalResolverMappings(groups)) + { + var keyExpr = mapping.Key ?? KeyedServiceAnyKey; + + string resolverExpr; + if(entryByResolverMethodName.TryGetValue(mapping.ResolverMethodName, out var entry)) + { + resolverExpr = GetResolverExpression(entry); + } + else + { + // Fallback path for compatibility with any unmapped resolver entries. + resolverExpr = $"static c => c.{mapping.ResolverMethodName}()"; + } + + writer.WriteLine($"new(new ServiceIdentifier(typeof({mapping.ServiceType}), {keyExpr}), {resolverExpr}),"); + } + + // Add IEnumerable, IReadOnlyCollection, ICollection, IReadOnlyList, IList, T[] entries for collection service types + foreach(var serviceType in groups.CollectionServiceTypes) + { + var methodName = GetArrayResolverMethodName(serviceType); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IEnumerable<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyCollection<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.ICollection<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyList<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IList<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); + writer.WriteLine($"new(new ServiceIdentifier(typeof({serviceType}[]), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); + } + + // Add local resolvers for wrapper entries (KVP first, then Lazy, then Func) + WriteWrapperLocalResolverEntries(writer, groups.WrapperEntries); + + writer.Indentation--; + writer.WriteLine("];"); + + // Write _serviceResolvers as static field (only when not using switch statement) + if(!effectiveUseSwitchStatement) + { + writer.WriteLine(); + if(container.ImportedModules.Length > 0) + { + // Combine with imported modules - wrap module resolvers to pass the correct module instance + // Use static access (module.Name is the fully qualified type name) for static abstract Resolvers + writer.WriteLine($"private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers ="); + writer.Indentation++; + + var isFirst = true; + foreach(var module in container.ImportedModules) + { + var fieldName = GetModuleFieldName(module.Name); + var source = $"{module.Name}.Resolvers.Select(static kvp => new KeyValuePair>(kvp.Key, c => kvp.Value(c.{fieldName})))"; + + if(isFirst) + { + writer.WriteLine(source); + isFirst = false; + } + else + { + writer.WriteLine($".Concat({source})"); + } + } + + writer.WriteLine(".Concat(_localResolvers)"); + writer.WriteLine(".ToFrozenDictionary();"); + writer.Indentation--; + } + else + { + writer.WriteLine($"private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary();"); + } + } + + writer.WriteLine(); + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes IServiceProviderFactory implementation. + /// + private static void WriteIServiceProviderFactoryImplementation(SourceWriter writer, ContainerModel container) + { + writer.WriteLine("#region IServiceProviderFactory"); + writer.WriteLine(); + + writer.WriteLine("/// "); + writer.WriteLine("/// Creates a new container builder (returns the same IServiceCollection)."); + writer.WriteLine("/// "); + writer.WriteLine("public IServiceCollection CreateBuilder(IServiceCollection services) => services;"); + writer.WriteLine(); + + writer.WriteLine("/// "); + writer.WriteLine("/// Creates the service provider from the built IServiceCollection."); + writer.WriteLine("/// "); + writer.WriteLine("public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder);"); + writer.WriteLine($"return new {container.ClassName}(fallbackProvider);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes Disposal implementation (IDisposable and IAsyncDisposable). + /// + private static void WriteDisposalImplementation( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups) + { + writer.WriteLine("#region Disposal"); + writer.WriteLine(); + + // IDisposable.Dispose + writer.WriteLine("public void Dispose()"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if(Interlocked.Exchange(ref _disposed, 1) != 0) return;"); + writer.WriteLine(); + WriteDisposalBody(writer, container, groups, isAsync: false); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // IAsyncDisposable.DisposeAsync + writer.WriteLine("public async ValueTask DisposeAsync()"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if(Interlocked.Exchange(ref _disposed, 1) != 0) return;"); + writer.WriteLine(); + WriteDisposalBody(writer, container, groups, isAsync: true); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + var hasAsyncInitServices = groups.SingletonEntries.Any(static e => e is AsyncContainerEntry) + || groups.ScopedEntries.Any(static e => e is AsyncContainerEntry); + WriteDisposalHelperMethods(writer, hasAsyncInitServices); + + writer.WriteLine("#endregion"); + } + + /// + /// Writes IControllerActivator implementation using ActivatorUtilities with ObjectFactory caching. + /// + private static void WriteIControllerActivatorImplementation(SourceWriter writer, ContainerModel container) + { + writer.WriteLine("#region IControllerActivator"); + writer.WriteLine(); + + // Static cache field + writer.WriteLine("private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _controllerFactoryCache = new();"); + writer.WriteLine("private static global::Microsoft.Extensions.DependencyInjection.ObjectFactory CreateControllerFactory("); + writer.WriteLine(" [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers("); + writer.WriteLine(" global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] global::System.Type t)"); + writer.WriteLine(" => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateFactory(t, global::System.Type.EmptyTypes);"); + writer.WriteLine(); + + // Create method + writer.WriteLine("object global::Microsoft.AspNetCore.Mvc.Controllers.IControllerActivator.Create(global::Microsoft.AspNetCore.Mvc.ControllerContext controllerContext)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(controllerContext);"); + writer.WriteLine("var controllerType = controllerContext.ActionDescriptor.ControllerTypeInfo.AsType();"); + writer.WriteLine("var instance = GetService(controllerType);"); + writer.WriteLine("if (instance is not null) return instance;"); + writer.WriteLine(); + writer.WriteLine("if (!_controllerFactoryCache.TryGetValue(controllerType, out var factory))"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("factory = CreateControllerFactory(controllerType);"); + writer.WriteLine("_controllerFactoryCache.TryAdd(controllerType, factory);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + writer.WriteLine("return factory(this, []);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // Release method + writer.WriteLine("void global::Microsoft.AspNetCore.Mvc.Controllers.IControllerActivator.Release(global::Microsoft.AspNetCore.Mvc.ControllerContext context, object controller)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(context);"); + writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(controller);"); + writer.WriteLine("if (controller is global::System.IDisposable disposable)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("disposable.Dispose();"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // ReleaseAsync method + writer.WriteLine("global::System.Threading.Tasks.ValueTask global::Microsoft.AspNetCore.Mvc.Controllers.IControllerActivator.ReleaseAsync(global::Microsoft.AspNetCore.Mvc.ControllerContext context, object controller)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(context);"); + writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(controller);"); + writer.WriteLine("if (controller is global::System.IAsyncDisposable asyncDisposable)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("return asyncDisposable.DisposeAsync();"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + writer.WriteLine("((global::Microsoft.AspNetCore.Mvc.Controllers.IControllerActivator)this).Release(context, controller);"); + writer.WriteLine("return default;"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes IComponentActivator implementation using ActivatorUtilities with ObjectFactory caching. + /// + private static void WriteIComponentActivatorImplementation(SourceWriter writer, ContainerModel container) + { + writer.WriteLine("#region IComponentActivator"); + writer.WriteLine(); + + // Static cache field + writer.WriteLine("private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _componentFactoryCache = new();"); + writer.WriteLine("private static global::Microsoft.Extensions.DependencyInjection.ObjectFactory CreateComponentFactory("); + writer.WriteLine(" [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers("); + writer.WriteLine(" global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] global::System.Type t)"); + writer.WriteLine(" => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateFactory(t, global::System.Type.EmptyTypes);"); + writer.WriteLine(); + + // CreateInstance method + writer.WriteLine("global::Microsoft.AspNetCore.Components.IComponent global::Microsoft.AspNetCore.Components.IComponentActivator.CreateInstance([global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] global::System.Type componentType)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if (!typeof(global::Microsoft.AspNetCore.Components.IComponent).IsAssignableFrom(componentType))"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("throw new global::System.ArgumentException($\"The type {componentType.FullName} does not implement {nameof(global::Microsoft.AspNetCore.Components.IComponent)}.\", nameof(componentType));"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + writer.WriteLine("var instance = GetService(componentType);"); + writer.WriteLine("if (instance is global::Microsoft.AspNetCore.Components.IComponent component) return component;"); + writer.WriteLine(); + writer.WriteLine("if (!_componentFactoryCache.TryGetValue(componentType, out var factory))"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("factory = CreateComponentFactory(componentType);"); + writer.WriteLine("_componentFactoryCache.TryAdd(componentType, factory);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + writer.WriteLine("return (global::Microsoft.AspNetCore.Components.IComponent)factory(this, []);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + private static void WriteIComponentPropertyActivatorImplementation(SourceWriter writer, ContainerModel container, bool effectiveUseSwitchStatement) + { + writer.WriteLine("#region IComponentPropertyActivator"); + writer.WriteLine(); + + // Static cache field + writer.WriteLine("private static readonly global::System.Collections.Concurrent.ConcurrentDictionary> _propertyActivatorCache = new();"); + writer.WriteLine(); + + // GetActivator method - explicit interface implementation + writer.WriteLine("global::System.Action global::Microsoft.AspNetCore.Components.IComponentPropertyActivator.GetActivator("); + writer.WriteLine(" [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All)] global::System.Type componentType)"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine("if (!_propertyActivatorCache.TryGetValue(componentType, out var activator))"); + writer.WriteLine("{"); + writer.Indentation++; + + // No-op optimization: only when also implementing IComponentActivator + if(container.ImplementComponentActivator) + { + if(effectiveUseSwitchStatement) + { + writer.WriteLine("activator = global::System.Array.Exists(_localResolvers, e => e.Key.Equals(new ServiceIdentifier(componentType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey)))"); + } + else + { + writer.WriteLine("activator = _serviceResolvers.ContainsKey(new ServiceIdentifier(componentType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))"); + } + writer.WriteLine(" ? static (_, _) => { }"); + writer.WriteLine(" : CreateComponentPropertyInjector(componentType);"); + } + else + { + writer.WriteLine("activator = CreateComponentPropertyInjector(componentType);"); + } + + writer.WriteLine("_propertyActivatorCache.TryAdd(componentType, activator);"); + writer.Indentation--; + writer.WriteLine("}"); + + writer.WriteLine("return activator;"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // CreateComponentPropertyInjector - reflection fallback method + WriteCreateComponentPropertyInjectorMethod(writer); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + private static void WriteCreateComponentPropertyInjectorMethod(SourceWriter writer) + { + writer.WriteLine("private static global::System.Action CreateComponentPropertyInjector("); + writer.WriteLine(" [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All)] global::System.Type componentType)"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine("const global::System.Reflection.BindingFlags flags = global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic;"); + writer.WriteLine("global::System.Collections.Generic.List<(global::System.Reflection.PropertyInfo Property, object? Key)>? injectables = null;"); + writer.WriteLine(); + + // Walk inheritance chain + writer.WriteLine("for (var type = componentType; type is not null; type = type.BaseType)"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine("foreach (var property in type.GetProperties(flags))"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine("if (property.DeclaringType != type) continue;"); + writer.WriteLine("var injectAttr = property.GetCustomAttributes(true)"); + writer.WriteLine(" .FirstOrDefault(a => a.GetType().Name is \"InjectAttribute\" or \"IocInjectAttribute\");"); + writer.WriteLine("if (injectAttr is null) continue;"); + writer.WriteLine(); + writer.WriteLine("var keyProp = injectAttr.GetType().GetProperty(\"Key\");"); + writer.WriteLine("var key = keyProp?.GetValue(injectAttr);"); + writer.WriteLine("injectables ??= new();"); + writer.WriteLine("injectables.Add((property, key));"); + + writer.Indentation--; + writer.WriteLine("}"); + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // Return no-op if no injectables found + writer.WriteLine("if (injectables is null) return static (_, _) => { };"); + writer.WriteLine(); + + // Return injection delegate + writer.WriteLine("return (serviceProvider, component) =>"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("foreach (var (property, serviceKey) in injectables)"); + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine("object? value;"); + writer.WriteLine("if (serviceKey is not null)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if (serviceProvider is not global::Microsoft.Extensions.DependencyInjection.IKeyedServiceProvider keyedProvider)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("throw new global::System.InvalidOperationException($\"Cannot provide a value for property '{property.Name}' on type '{componentType.FullName}'. The service provider does not implement 'IKeyedServiceProvider'.\");"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("value = keyedProvider.GetRequiredKeyedService(property.PropertyType, serviceKey);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("else"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("value = serviceProvider.GetService(property.PropertyType) ?? throw new global::System.InvalidOperationException($\"Cannot provide a value for property '{property.Name}' on type '{componentType.FullName}'. There is no registered service of type '{property.PropertyType}'.\");"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("property.SetValue(component, value);"); + + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("};"); + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + } + + /// + /// Writes the __HotReloadHandler nested class for hot reload cache invalidation. + /// + private static void WriteHotReloadHandler(SourceWriter writer, ContainerModel container) + { + writer.WriteLine("[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]"); + writer.WriteLine("internal static class __HotReloadHandler"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("public static void ClearCache(global::System.Type[]? _)"); + writer.WriteLine("{"); + writer.Indentation++; + if(container.ImplementComponentActivator) + { + writer.WriteLine("_componentFactoryCache.Clear();"); + } + if(container.ImplementComponentPropertyActivator) + { + writer.WriteLine("_propertyActivatorCache.Clear();"); + } + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + } + + /// + /// Writes the disposal body for both sync and async disposal methods. + /// + private static void WriteDisposalBody( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups, + bool isAsync) + { + // Dispose scoped services if this is a scope + writer.WriteLine("if(!_isRootScope)"); + writer.WriteLine("{"); + writer.Indentation++; + + WriteDisposalCalls(writer, groups.ScopedEntries, container.ImportedModules, isAsync); + writer.WriteLine("return;"); + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // Root scope disposal + WriteDisposalCalls(writer, groups.SingletonEntries, container.ImportedModules, isAsync); + } + + /// + /// Writes disposal calls for services and modules. + /// + private static void WriteDisposalCalls( + SourceWriter writer, + ImmutableEquatableArray services, + ImmutableEquatableArray modules, + bool isAsync) + { + for(var i = services.Length - 1; i >= 0; i--) + { + services[i].WriteDisposal(writer, isAsync); + } + + var moduleMethod = isAsync + ? "await {0}.DisposeAsync()" + : "{0}.Dispose()"; + + foreach(var module in modules) + { + var fieldName = GetModuleFieldName(module.Name); + writer.WriteLine(string.Format(moduleMethod, fieldName) + ";"); + } + } + + /// + /// Writes the static helper methods for disposal. + /// + private static void WriteDisposalHelperMethods(SourceWriter writer, bool hasAsyncInitServices) + { + // Helper method to throw ObjectDisposedException if disposed + writer.WriteLine("private void ThrowIfDisposed()"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("ObjectDisposedException.ThrowIf(_disposed != 0, GetType());"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // Helper method for async disposal + writer.WriteLine("private static async ValueTask DisposeServiceAsync(object? service)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync();"); + writer.WriteLine("else if(service is IDisposable disposable) disposable.Dispose();"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + // Helper method for sync disposal + writer.WriteLine("private static void DisposeService(object? service)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if(service is IDisposable disposable) disposable.Dispose();"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + if(hasAsyncInitServices) + { + // Overload for async-init services stored as Task? + writer.WriteLine("private static async ValueTask DisposeServiceAsync(Task? task)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if(task is { IsCompletedSuccessfully: true })"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("try"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("await DisposeServiceAsync(await task);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("catch(Exception ex)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("private static void DisposeService(Task? task)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("if(task is { IsCompletedSuccessfully: true })"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("try"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult());"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine("catch(Exception ex)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex);"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + } + } + + private static IEnumerable<(string ServiceType, string? Key, string ResolverExpression)> EnumerateInterfaceServiceResolvers(ContainerRegistrationGroups groups) + { + foreach(var kvp in groups.LastWinsByServiceType) + { + if(kvp.Value is not ServiceContainerEntry serviceEntry) + { + continue; + } + + yield return (serviceEntry.Registration.ServiceType.Name, kvp.Key.Key, GetInterfaceResolverExpression(kvp.Value)); + } + } + + private static IEnumerable<(string ServiceType, string? Key, string ResolverMethodName)> EnumerateLocalResolverMappings(ContainerRegistrationGroups groups) + { + foreach(var kvp in groups.LastWinsByServiceType) + { + if(!TryGetResolverMethodName(kvp.Value, out var resolverMethodName)) + { + continue; + } + + yield return (kvp.Key.ServiceType, kvp.Key.Key, resolverMethodName); + } + } + + private static IEnumerable<(string ServiceType, string? Key, string ResolverExpression)> EnumerateKeyedServiceResolvers(ContainerRegistrationGroups groups) + { + foreach(var kvp in groups.LastWinsByServiceType) + { + if(kvp.Key.Key is null) + { + continue; + } + + yield return (kvp.Key.ServiceType, kvp.Key.Key, GetInterfaceResolverExpression(kvp.Value)); + } + } + + private static void WriteWrapperLocalResolverEntries( + SourceWriter writer, + ImmutableEquatableArray wrapperEntries) + { + var kvpEntries = wrapperEntries + .OfType() + .ToList(); + + if(kvpEntries.Count > 0) + { + writer.WriteLine(); + writer.WriteLine("// KeyValuePair resolvers"); + + foreach(var group in kvpEntries.GroupBy(static entry => (entry.KeyTypeName, entry.ValueTypeName))) + { + group.First().WriteLocalResolverEntries(writer); + } + } + + var lazyEntries = wrapperEntries + .OfType() + .ToList(); + + if(lazyEntries.Count > 0) + { + writer.WriteLine(); + writer.WriteLine("// Lazy wrapper resolvers"); + + foreach(var group in lazyEntries.GroupBy(static entry => entry.InnerServiceTypeName)) + { + group.Last().WriteLocalResolverEntries(writer); + } + } + + var funcEntries = wrapperEntries + .OfType() + .ToList(); + + if(funcEntries.Count > 0) + { + writer.WriteLine(); + writer.WriteLine("// Func wrapper resolvers"); + + foreach(var group in funcEntries.GroupBy(static entry => entry.InnerServiceTypeName)) + { + group.Last().WriteLocalResolverEntries(writer); + } + } + } + + private static Dictionary BuildEntryByResolverMethodName(ContainerRegistrationGroups groups) + { + var map = new Dictionary(StringComparer.Ordinal); + + foreach(var entry in groups.SingletonEntries) + AddEntryByResolverMethodName(map, entry); + foreach(var entry in groups.ScopedEntries) + AddEntryByResolverMethodName(map, entry); + foreach(var entry in groups.TransientEntries) + AddEntryByResolverMethodName(map, entry); + + return map; + } + + private static void AddEntryByResolverMethodName(Dictionary map, ContainerEntry entry) + { + if(!TryGetResolverMethodName(entry, out var resolverMethodName)) + return; + + if(!map.ContainsKey(resolverMethodName)) + { + map[resolverMethodName] = entry; + } + } + + private static bool TryGetResolverMethodName(ContainerEntry entry, out string resolverMethodName) + { + switch(entry) + { + case InstanceContainerEntry instance: + resolverMethodName = instance.ResolverMethodName; + return true; + case EagerContainerEntry eager: + resolverMethodName = eager.ResolverMethodName; + return true; + case LazyThreadSafeContainerEntry lazy: + resolverMethodName = lazy.ResolverMethodName; + return true; + case TransientContainerEntry transient: + resolverMethodName = transient.ResolverMethodName; + return true; + case AsyncContainerEntry asyncSingletonOrScoped: + resolverMethodName = asyncSingletonOrScoped.ResolverMethodName; + return true; + case AsyncTransientContainerEntry asyncTransient: + resolverMethodName = asyncTransient.ResolverMethodName; + return true; + default: + resolverMethodName = string.Empty; + return false; + } + } + + private static string GetResolverExpression(ContainerEntry entry) + { + return entry switch + { + InstanceContainerEntry instance => $"static _ => {instance.Registration.Instance}", + EagerContainerEntry eager => $"static c => c.{eager.FieldName}!", + LazyThreadSafeContainerEntry lazy => $"static c => c.{lazy.ResolverMethodName}()", + TransientContainerEntry transient => $"static c => c.{transient.ResolverMethodName}()", + AsyncContainerEntry asyncSingletonOrScoped => $"static c => c.{GetAsyncResolverMethodName(asyncSingletonOrScoped.ResolverMethodName)}()", + AsyncTransientContainerEntry asyncTransient => $"static c => c.{GetAsyncCreateMethodName(asyncTransient.ResolverMethodName)}()", + _ => throw new InvalidOperationException($"Unsupported container entry type: {entry.GetType().Name}") + }; + } + + private static string GetInterfaceResolverExpression(ContainerEntry entry) + { + return entry switch + { + InstanceContainerEntry instance => instance.Registration.Instance!, + AsyncContainerEntry asyncSingletonOrScoped => $"{GetAsyncResolverMethodName(asyncSingletonOrScoped.ResolverMethodName)}()", + AsyncTransientContainerEntry asyncTransient => $"{GetAsyncCreateMethodName(asyncTransient.ResolverMethodName)}()", + ServiceContainerEntry serviceEntry => $"{serviceEntry.ResolverMethodName}()", + _ => throw new InvalidOperationException($"Unsupported container entry type: {entry.GetType().Name}") + }; + } +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerOutputModels.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerOutputModels.cs new file mode 100644 index 0000000..e730553 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ContainerOutputModels.cs @@ -0,0 +1,354 @@ +using InjectionMemberModel = SourceGen.Ioc.SourceGenerator.Models.InjectionMemberData; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + private sealed record class DirectServiceDependency(string ResolverMethodName) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return $"{ResolverMethodName}()"; + } + } + + private sealed record class CollectionDependency(string ArrayMethodName) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return $"{ArrayMethodName}()"; + } + } + + private sealed record class LazyFieldReferenceDependency(string FieldName) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return FieldName; + } + } + + private sealed record class LazyInlineDependency(string ServiceTypeName, ResolvedDependency Inner) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + var innerExpr = Inner.FormatExpression(isOptional); + return $"new global::System.Lazy<{ServiceTypeName}>(() => {innerExpr}, global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)"; + } + } + + private sealed record class FuncFieldReferenceDependency(string FieldName) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return FieldName; + } + } + + private sealed record class FuncInlineDependency(string ServiceTypeName, ResolvedDependency Inner) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + var innerExpr = Inner.FormatExpression(isOptional); + return $"new global::System.Func<{ServiceTypeName}>(() => {innerExpr})"; + } + } + + private sealed record class MultiParamFuncDependency( + string ReturnTypeName, + ImmutableEquatableArray InputParameters, + ImmutableEquatableArray ConstructorParameters, + ImmutableEquatableArray InjectionMembers, + ImmutableEquatableArray Decorators, + string? ImplementationTypeName = null) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + var inputArgNames = new string[InputParameters.Length]; + var inputArgTypeNames = new string[InputParameters.Length]; + var inputArgUsed = new bool[InputParameters.Length]; + + for(var i = 0; i < InputParameters.Length; i++) + { + inputArgNames[i] = $"arg{i}"; + inputArgTypeNames[i] = InputParameters[i].Type.Name; + inputArgUsed[i] = false; + } + + var lambdaParams = string.Join(", ", InputParameters.Select(static (param, i) => $"{param.Type.Name} arg{i}")); + + var statements = new List(); + var ctorEntries = new List<(string Name, string? Value)>(ConstructorParameters.Length); + + var resolvedParamIndex = 0; + foreach(var param in ConstructorParameters) + { + var matchedArg = TryConsumeMatchingFuncInputArg(param.Parameter.Type.Name, inputArgNames, inputArgTypeNames, inputArgUsed); + if(matchedArg is not null) + { + ctorEntries.Add((param.Parameter.Name, matchedArg)); + continue; + } + + var paramVar = $"p{resolvedParamIndex}"; + var expr = param.Dependency.FormatExpression(param.IsOptional); + statements.Add($"var {paramVar} = {expr};"); + ctorEntries.Add((param.Parameter.Name, paramVar)); + resolvedParamIndex++; + } + + var propertyInits = new List(); + var propertyIndex = 0; + foreach(var injectionMember in InjectionMembers) + { + var member = injectionMember.Member; + if(member.MemberType is not (InjectionMemberType.Property or InjectionMemberType.Field)) + continue; + + var memberType = member.Type; + if(memberType is null) + continue; + + var matchedArg = TryConsumeMatchingFuncInputArg(memberType.Name, inputArgNames, inputArgTypeNames, inputArgUsed); + if(matchedArg is not null) + { + propertyInits.Add($"{member.Name} = {matchedArg}"); + continue; + } + + if(injectionMember.Dependency is null) + { + throw new InvalidOperationException($"Missing resolved dependency for injection member '{member.Name}'."); + } + + var memberVar = $"s0_p{propertyIndex}"; + var expr = injectionMember.Dependency.FormatExpression(member.IsNullable); + statements.Add($"var {memberVar} = {expr};"); + propertyInits.Add($"{member.Name} = {memberVar}"); + propertyIndex++; + } + + var implementationType = ImplementationTypeName ?? ReturnTypeName; + var ctorArgs = BuildArgumentListFromEntries(ctorEntries); + var initializerPart = propertyInits.Count > 0 ? $" {{ {string.Join(", ", propertyInits)} }}" : string.Empty; + var ctorInvocation = BuildConstructorInvocation(implementationType, ctorArgs, initializerPart); + statements.Add($"var s0 = {ctorInvocation};"); + + var methodIndex = 0; + foreach(var injectionMember in InjectionMembers) + { + var member = injectionMember.Member; + if(member.MemberType != InjectionMemberType.Method) + continue; + + var methodParams = member.Parameters ?? []; + var methodEntries = new List<(string Name, string? Value)>(methodParams.Length); + foreach(var param in methodParams) + { + var matchedArg = TryConsumeMatchingFuncInputArg(param.Type.Name, inputArgNames, inputArgTypeNames, inputArgUsed); + if(matchedArg is not null) + { + methodEntries.Add((param.Name, matchedArg)); + continue; + } + + if(injectionMember.Dependency is null) + { + throw new InvalidOperationException($"Missing resolved dependency for method parameter '{param.Name}' in member '{member.Name}'."); + } + + var paramVar = $"s0_m{methodIndex}"; + var expr = injectionMember.Dependency.FormatExpression(param.IsOptional); + statements.Add($"var {paramVar} = {expr};"); + methodEntries.Add((param.Name, paramVar)); + methodIndex++; + } + + var methodArgs = BuildArgumentListFromEntries(methodEntries); + statements.Add($"s0.{member.Name}({methodArgs});"); + } + + statements.Add("return s0;"); + + var funcTypeName = BuildMultiParamFuncTypeName(InputParameters, ReturnTypeName); + return $"new {funcTypeName}(({lambdaParams}) => {{ {string.Join(" ", statements)} }})"; + } + + private static string BuildMultiParamFuncTypeName( + ImmutableEquatableArray inputParameters, + string returnTypeName) + { + if(inputParameters.Length == 0) + { + return $"global::System.Func<{returnTypeName}>"; + } + + var inputTypeList = string.Join(", ", inputParameters.Select(static p => p.Type.Name)); + return $"global::System.Func<{inputTypeList}, {returnTypeName}>"; + } + + private static string? TryConsumeMatchingFuncInputArg( + string requestedTypeName, + string[] inputArgNames, + string[] inputArgTypeNames, + bool[] inputArgUsed) + { + for(var i = 0; i < inputArgTypeNames.Length; i++) + { + if(inputArgUsed[i]) + continue; + + if(!string.Equals(inputArgTypeNames[i], requestedTypeName, StringComparison.Ordinal)) + continue; + + inputArgUsed[i] = true; + return inputArgNames[i]; + } + + return null; + } + } + + private sealed record class TaskFromResultDependency(ResolvedDependency Inner, string TypeName) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return $"global::System.Threading.Tasks.Task.FromResult(({TypeName}){Inner.FormatExpression(false)})"; + } + } + + private sealed record class TaskAsyncDependency(string AsyncMethodName, string TypeName) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return $"((global::System.Func>)(async () => ({TypeName})(await {AsyncMethodName}())))()"; + } + } + + private sealed record class KvpInlineDependency( + string KeyType, + string ValueType, + string KeyExpr, + ResolvedDependency Inner) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return $"new global::System.Collections.Generic.KeyValuePair<{KeyType}, {ValueType}>({KeyExpr}, {Inner.FormatExpression(isOptional)})"; + } + } + + private sealed record class KvpResolverDependency(string MethodName) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return $"{MethodName}()"; + } + } + + private sealed record class DictionaryResolverDependency(string MethodName) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return $"{MethodName}()"; + } + } + + private sealed record class DictionaryFallbackDependency( + string KvpTypeName, + bool IsKeyed, + string? Key) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + if(IsKeyed) + { + return $"GetKeyedServices<{KvpTypeName}>({Key}).ToDictionary()"; + } + + return $"GetServices<{KvpTypeName}>().ToDictionary()"; + } + } + + private sealed record class ServiceProviderSelfDependency : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return "this"; + } + } + + private sealed record class FallbackProviderDependency( + string TypeName, + string? Key, + bool IsOptional) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return BuildServiceProviderFallbackExpression(TypeName, Key, IsOptional || isOptional); + } + + private static string BuildServiceProviderFallbackExpression( + string typeName, + string? key, + bool isOptional) + { + if(key is not null) + { + return isOptional + ? $"GetKeyedService(typeof({typeName}), {key}) as {typeName}" + : $"({typeName})GetRequiredKeyedService(typeof({typeName}), {key})"; + } + + return isOptional + ? $"GetService(typeof({typeName})) as {typeName}" + : $"({typeName})GetRequiredService(typeof({typeName}))"; + } + } + + private sealed record class CollectionFallbackDependency( + string ElementType, + bool IsKeyed, + string? Key) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + if(IsKeyed) + { + return $"GetKeyedServices<{ElementType}>({Key})"; + } + + return $"GetServices<{ElementType}>()"; + } + } + + private sealed record class ServiceKeyLiteralDependency(string KeyType, string KeyValue) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return KeyValue; + } + } + + private sealed record class InstanceExpressionDependency(string Expression) : ResolvedDependency + { + public override string FormatExpression(bool isOptional) + { + return Expression; + } + } + + private readonly record struct ResolvedConstructorParameter( + ParameterData Parameter, + ResolvedDependency Dependency, + bool IsOptional); + + private readonly record struct ResolvedInjectionMember( + InjectionMemberModel Member, + ResolvedDependency? Dependency, + ImmutableEquatableArray ParameterDependencies); + + private sealed record class ResolvedDecorator( + ServiceRegistrationModel Decorator, + ImmutableEquatableArray Parameters, + ImmutableEquatableArray InjectionMembers); +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/GenerateContainerOutput.Resolvers.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/GenerateContainerOutput.Resolvers.cs new file mode 100644 index 0000000..1e4a2a3 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/GenerateContainerOutput.Resolvers.cs @@ -0,0 +1,340 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Writes individual service resolver methods. + /// + private static void WriteServiceResolverSection( + SourceWriter writer, + ContainerRegistrationGroups groups) + { + writer.WriteLine("#region Service Resolution"); + writer.WriteLine(); + + var writtenMethods = new HashSet(StringComparer.Ordinal); + + WriteServiceResolverGroup(writer, groups.SingletonEntries, writtenMethods, writeField: true); + WriteServiceResolverGroup(writer, groups.ScopedEntries, writtenMethods, writeField: true); + WriteServiceResolverGroup(writer, groups.TransientEntries, writtenMethods, writeField: false); + + foreach(var entry in groups.CollectionEntries) + { + entry.WriteResolver(writer); + writer.WriteLine(); + } + + var kvpEntries = groups.WrapperEntries + .OfType() + .ToList(); + + if(kvpEntries.Count > 0) + { + writer.WriteLine("// KeyValuePair resolver methods"); + writer.WriteLine(); + + foreach(var entry in kvpEntries) + { + entry.WriteResolver(writer); + writer.WriteLine(); + } + + WriteKvpCollectionResolvers(writer, kvpEntries); + } + + var lazyEntries = groups.WrapperEntries + .OfType() + .ToList(); + + if(lazyEntries.Count > 0) + { + writer.WriteLine("// Lazy wrapper fields"); + writer.WriteLine(); + + foreach(var entry in lazyEntries) + { + entry.WriteField(writer); + } + + writer.WriteLine(); + + foreach(var entry in lazyEntries) + { + if(!entry.EmitCollectionResolver) + continue; + + entry.WriteCollectionResolver(writer); + writer.WriteLine(); + } + } + + var funcEntries = groups.WrapperEntries + .OfType() + .ToList(); + + if(funcEntries.Count > 0) + { + writer.WriteLine("// Func wrapper fields"); + writer.WriteLine(); + + foreach(var entry in funcEntries) + { + entry.WriteField(writer); + } + + writer.WriteLine(); + + foreach(var entry in funcEntries) + { + if(!entry.EmitCollectionResolver) + continue; + + entry.WriteCollectionResolver(writer); + writer.WriteLine(); + } + } + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + private static void WriteServiceResolverGroup( + SourceWriter writer, + ImmutableEquatableArray entries, + HashSet writtenMethods, + bool writeField) + { + foreach(var entry in entries) + { + if(!TryGetResolverMethodName(entry, out var resolverMethodName)) + continue; + + if(!writtenMethods.Add(resolverMethodName)) + continue; + + if(writeField) + { + entry.WriteField(writer); + + if(entry is AsyncContainerEntry) + { + writer.WriteLine(); + } + } + + entry.WriteResolver(writer); + writer.WriteLine(); + } + } + + private static void WriteKvpCollectionResolvers( + SourceWriter writer, + List kvpEntries) + { + var grouped = kvpEntries + .GroupBy(static e => (e.KeyTypeName, e.ValueTypeName)) + .ToList(); + + foreach(var group in grouped) + { + var (keyTypeName, valueTypeName) = group.Key; + var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{keyTypeName}, {valueTypeName}>"; + var arrayMethodName = GetKvpArrayResolverMethodName(keyTypeName, valueTypeName); + + writer.WriteLine($"private {kvpTypeName}[] {arrayMethodName}() =>"); + writer.Indentation++; + writer.WriteLine("["); + writer.Indentation++; + + foreach(var entry in group) + { + writer.WriteLine($"{entry.KvpResolverMethodName}(),"); + } + + writer.Indentation--; + writer.WriteLine("];"); + writer.Indentation--; + writer.WriteLine(); + } + + foreach(var group in grouped) + { + var (keyTypeName, valueTypeName) = group.Key; + var dictionaryMethodName = GetKvpDictionaryResolverMethodName(keyTypeName, valueTypeName); + + writer.WriteLine($"private global::System.Collections.Generic.Dictionary<{keyTypeName}, {valueTypeName}> {dictionaryMethodName}() =>"); + writer.Indentation++; + writer.WriteLine($"new global::System.Collections.Generic.Dictionary<{keyTypeName}, {valueTypeName}>()"); + writer.WriteLine("{"); + writer.Indentation++; + + foreach(var entry in group) + { + writer.WriteLine($"[{entry.KeyExpr}] = {entry.ResolverMethodName}(),"); + } + + writer.Indentation--; + writer.WriteLine("};"); + writer.Indentation--; + writer.WriteLine(); + } + } + + /// + /// Writes individual service resolver methods. + /// + private static void WritePartialAccessorImplementations( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups) + { + if(container.PartialAccessors.Length == 0) + return; + + writer.WriteLine("#region Partial Accessor Implementations"); + writer.WriteLine(); + + foreach(var accessor in container.PartialAccessors) + { + var isTaskReturn = accessor.Kind == PartialAccessorKind.Method + && TryExtractTaskInnerType(accessor.ReturnTypeName, out _); + + var resolveExpression = ResolvePartialAccessorExpression(accessor, container, groups); + + switch(accessor.Kind) + { + case PartialAccessorKind.Method: + // Async partial methods (returning Task) require the 'async' modifier + if(isTaskReturn) + { + writer.WriteLine($"public partial async {accessor.ReturnTypeName} {accessor.Name}() => {resolveExpression};"); + } + else + { + writer.WriteLine($"public partial {accessor.ReturnTypeName}{(accessor.IsNullable ? "?" : "")} {accessor.Name}() => {resolveExpression};"); + } + break; + + case PartialAccessorKind.Property: + writer.WriteLine($"public partial {accessor.ReturnTypeName}{(accessor.IsNullable ? "?" : "")} {accessor.Name} {{ get => {resolveExpression}; }}"); + break; + } + + writer.WriteLine(); + } + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Resolves the expression to use for a partial accessor implementation. + /// Looks up the registration by return type and optional key, with fallback to IServiceProvider. + /// For Task<T> return types, routes through the async resolver. + /// + private static string ResolvePartialAccessorExpression( + PartialAccessorData accessor, + ContainerModel container, + ContainerRegistrationGroups groups) + { + var serviceType = accessor.ReturnTypeName; + var key = accessor.Key; + + // Handle Task return types β€” route through the async resolver. + if(TryExtractTaskInnerType(serviceType, out var innerTypeName)) + { + if(groups.LastWinsByServiceType.TryGetValue((innerTypeName, key), out var entry)) + { + switch(entry) + { + case AsyncContainerEntry asyncEntry: + return $"await {GetAsyncResolverMethodName(asyncEntry.ResolverMethodName)}()"; + case AsyncTransientContainerEntry asyncTransientEntry: + return $"await {GetAsyncCreateMethodName(asyncTransientEntry.ResolverMethodName)}()"; + case InstanceContainerEntry instanceEntry: + return $"global::System.Threading.Tasks.Task.FromResult(({innerTypeName}){instanceEntry.Registration.Instance})"; + case ServiceContainerEntry serviceEntry: + return $"global::System.Threading.Tasks.Task.FromResult(({innerTypeName}){serviceEntry.ResolverMethodName}())"; + } + } + + // Fallback: delegate to IServiceProvider if available + if(container.IntegrateServiceProvider) + { + if(key is not null) + return $"({serviceType})GetRequiredKeyedService(typeof({serviceType}), {key})"; + return $"({serviceType})GetRequiredService(typeof({serviceType}))"; + } + + return $"""throw new global::System.InvalidOperationException("Service '{innerTypeName}' is not registered.")"""; + } + + // Try to find direct resolver in this container + if(groups.LastWinsByServiceType.TryGetValue((serviceType, key), out var directEntry)) + { + switch(directEntry) + { + case InstanceContainerEntry instanceEntry: + return instanceEntry.Registration.Instance!; + case AsyncContainerEntry asyncEntry: + return $"{GetAsyncResolverMethodName(asyncEntry.ResolverMethodName)}().ConfigureAwait(false).GetAwaiter().GetResult()"; + case AsyncTransientContainerEntry asyncTransientEntry: + return $"{GetAsyncCreateMethodName(asyncTransientEntry.ResolverMethodName)}().ConfigureAwait(false).GetAwaiter().GetResult()"; + case ServiceContainerEntry serviceEntry: + return $"{serviceEntry.ResolverMethodName}()"; + } + } + + // Fallback to GetService/GetRequiredService (only when IntegrateServiceProvider is enabled) + if(container.IntegrateServiceProvider) + { + if(key is not null) + { + return accessor.IsNullable + ? $"GetKeyedService(typeof({serviceType}), {key}) as {serviceType}" + : $"({serviceType})GetRequiredKeyedService(typeof({serviceType}), {key})"; + } + + return accessor.IsNullable + ? $"GetService(typeof({serviceType})) as {serviceType}" + : $"({serviceType})GetRequiredService(typeof({serviceType}))"; + } + + // No resolver found and no fallback: throw (analyzer should have caught this) + return accessor.IsNullable + ? "default" + : $"""throw new global::System.InvalidOperationException("Service '{serviceType}' is not registered.")"""; + } + + /// + /// Tries to extract the inner type name from a + /// global::System.Threading.Tasks.Task<T> type name string. + /// Returns and sets if matched. + /// + private static bool TryExtractTaskInnerType(string typeName, out string innerTypeName) + { + const string TaskPrefix = "global::System.Threading.Tasks.Task<"; + if(typeName.StartsWith(TaskPrefix, StringComparison.Ordinal) + && typeName.EndsWith(">", StringComparison.Ordinal)) + { + innerTypeName = typeName[TaskPrefix.Length..^1]; + return true; + } + innerTypeName = string.Empty; + return false; + } + + /// + /// Gets the array resolver method name for IEnumerable<T>, IReadOnlyCollection<T>, IReadOnlyList<T>, T[] resolution. + /// + private static string GetArrayResolverMethodName(string serviceType) + { + var baseName = GetSafeIdentifier(serviceType); + return $"GetAll{baseName}Array"; + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/GenerateContainerOutput.Structure.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/GenerateContainerOutput.Structure.cs new file mode 100644 index 0000000..2e6496f --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/GenerateContainerOutput.Structure.cs @@ -0,0 +1,398 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Writes the namespace and class declaration with all implemented interfaces. + /// + private static void WriteContainerNamespaceAndClass( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups, + bool canGenerateServiceProviderFactory, + bool effectiveUseSwitchStatement) + { + // Write namespace if not global + bool hasNamespace = !string.IsNullOrEmpty(container.ContainerNamespace); + if(hasNamespace) + { + writer.WriteLine($"namespace {container.ContainerNamespace};"); + writer.WriteLine(); + } + + // Get interface list + var interfaces = GetContainerInterfaces(container, canGenerateServiceProviderFactory); + + // Write class declaration + writer.WriteLine($"partial class {container.ClassName} : {string.Join(", ", interfaces)}"); + writer.WriteLine("{"); + writer.Indentation++; + + // Write fields + WriteContainerFields(writer, container); + + // Write constructors + WriteContainerConstructors(writer, container, groups, effectiveUseSwitchStatement); + + // Write service resolver methods + WriteServiceResolverSection(writer, groups); + + // Write partial accessor implementations (user-declared partial methods/properties) + WritePartialAccessorImplementations(writer, container, groups); + + // Write IServiceProvider implementation + WriteIServiceProviderImplementation(writer, container, groups, effectiveUseSwitchStatement); + + // Write IKeyedServiceProvider implementation + WriteIKeyedServiceProviderImplementation(writer, container, groups, effectiveUseSwitchStatement); + + // Write ISupportRequiredService implementation + WriteISupportRequiredServiceImplementation(writer, container); + + // Write ServiceProvider extension methods (generic overloads) + WriteServiceProviderExtensions(writer, container, effectiveUseSwitchStatement); + + // Write IServiceProviderIsService implementation + WriteIServiceProviderIsServiceImplementation(writer, container, groups, effectiveUseSwitchStatement); + + // Write IServiceScopeFactory implementation + WriteIServiceScopeFactoryImplementation(writer, container); + + // Write IIocContainer implementation + WriteIIocContainerImplementation(writer, container, groups, effectiveUseSwitchStatement); + + // Write IServiceProviderFactory implementation (if DI package is available) + if(canGenerateServiceProviderFactory) + { + WriteIServiceProviderFactoryImplementation(writer, container); + } + + // Write Disposal implementation + WriteDisposalImplementation(writer, container, groups); + + // Write IControllerActivator implementation (if container declares the interface) + if(container.ImplementControllerActivator) + { + WriteIControllerActivatorImplementation(writer, container); + } + + // Write IComponentActivator implementation (if container declares the interface) + if(container.ImplementComponentActivator) + { + WriteIComponentActivatorImplementation(writer, container); + } + + // Write IComponentPropertyActivator implementation (if container declares the interface) + if(container.ImplementComponentPropertyActivator) + { + WriteIComponentPropertyActivatorImplementation(writer, container, effectiveUseSwitchStatement); + } + + // Write hot reload handler (if either component activator interface is implemented) + if(container.ImplementComponentActivator || container.ImplementComponentPropertyActivator) + { + WriteHotReloadHandler(writer, container); + } + + writer.Indentation--; + writer.WriteLine("}"); + } + + private static readonly string[] _FixedContainerInterfaces = [ + "IServiceProvider", + "IKeyedServiceProvider", + "IServiceProviderIsService", + "IServiceProviderIsKeyedService", + "ISupportRequiredService", + "IServiceScopeFactory", + "IServiceScope", + "IDisposable", + "IAsyncDisposable" + ]; + + /// + /// Builds the list of interfaces the container should implement. + /// + private static IEnumerable GetContainerInterfaces(ContainerModel container, bool canGenerateServiceProviderFactory) + { + yield return $"IIocContainer<{container.ContainerTypeName}>"; + + foreach(var i in _FixedContainerInterfaces) + yield return i; + + if(canGenerateServiceProviderFactory) + yield return "IServiceProviderFactory"; + } + + /// + /// Writes container fields (fallback provider, service storage, locks, etc.). + /// + private static void WriteContainerFields( + SourceWriter writer, + ContainerModel container) + { + // Fallback provider field (only if IntegrateServiceProvider is enabled) + if(container.IntegrateServiceProvider) + { + writer.WriteLine("private readonly IServiceProvider? _fallbackProvider;"); + } + + // Scope tracking + writer.WriteLine("private readonly bool _isRootScope = true;"); + writer.WriteLine("private int _disposed;"); + writer.WriteLine(); + + // Imported module fields + foreach(var module in container.ImportedModules) + { + var fieldName = GetModuleFieldName(module.Name); + writer.WriteLine($"private readonly {module.Name} {fieldName};"); + } + + if(container.ImportedModules.Length > 0) + { + writer.WriteLine(); + } + } + + /// + /// Writes a service instance field and synchronization field based on ThreadSafeStrategy. + /// For eager services, fields are non-nullable and no synchronization is needed. + /// + private static void WriteServiceInstanceField( + SourceWriter writer, + ThreadSafeStrategy strategy, + ServiceRegistrationModel reg, + string fieldName, + bool hasDecorators, + bool isEager) + { + // When there are decorators, field type is ServiceType (interface), otherwise ImplementationType + var typeName = hasDecorators ? reg.ServiceType.Name : reg.ImplementationType.Name; + + writer.WriteLine(isEager + ? $"private {typeName} {fieldName} = null!;" + : $"private {typeName}? {fieldName};"); + + if(isEager) + return; + + // SpinLock must NOT be readonly because Enter/Exit mutate it + var syncFieldDeclaration = strategy switch + { + ThreadSafeStrategy.Lock => $"private readonly Lock {fieldName}Lock = new();", + ThreadSafeStrategy.SemaphoreSlim => $"private readonly SemaphoreSlim {fieldName}Semaphore = new(1, 1);", + ThreadSafeStrategy.SpinLock => $"private SpinLock {fieldName}SpinLock = new(false);", + _ => null + }; + + if(syncFieldDeclaration is not null) + { + writer.WriteLine(syncFieldDeclaration); + } + } + + /// + /// Writes container constructors. + /// + private static void WriteContainerConstructors( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups, + bool effectiveUseSwitchStatement) + { + writer.WriteLine("#region Constructors"); + writer.WriteLine(); + + // Default constructor + writer.WriteLine("/// "); + writer.WriteLine("/// Creates a new standalone container without external service provider fallback."); + writer.WriteLine("/// "); + + // Need fallback provider constructor if IntegrateServiceProvider is enabled + var needsFallbackProvider = container.IntegrateServiceProvider; + + if(needsFallbackProvider) + { + writer.WriteLine($"public {container.ClassName}() : this((IServiceProvider?)null) {{ }}"); + } + else + { + // Standalone mode - no fallback provider + writer.WriteLine($"public {container.ClassName}()"); + writer.WriteLine("{"); + writer.Indentation++; + WriteConstructorBody(writer, container, groups, hasParameter: false, effectiveUseSwitchStatement); + writer.Indentation--; + writer.WriteLine("}"); + } + writer.WriteLine(); + + // Constructor with fallback provider (if enabled) + if(needsFallbackProvider) + { + writer.WriteLine("/// "); + writer.WriteLine("/// Creates a new container with optional fallback to external service provider."); + writer.WriteLine("/// "); + writer.WriteLine("/// Optional external service provider for unknown dependencies."); + writer.WriteLine($"public {container.ClassName}(IServiceProvider? fallbackProvider)"); + writer.WriteLine("{"); + writer.Indentation++; + writer.WriteLine("_fallbackProvider = fallbackProvider;"); + WriteConstructorBody(writer, container, groups, hasParameter: true, effectiveUseSwitchStatement); + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + } + + // Private constructor for scoped instances + writer.WriteLine($"private {container.ClassName}({container.ClassName} parent)"); + writer.WriteLine("{"); + writer.Indentation++; + if(needsFallbackProvider) + { + writer.WriteLine("_fallbackProvider = parent._fallbackProvider;"); + } + writer.WriteLine("_isRootScope = false;"); + + // Copy singleton references from parent (already filtered for non-open-generics) + foreach(var entry in groups.SingletonEntries) + { + if(!TryGetServiceFieldName(entry, out var fieldName)) + continue; + + writer.WriteLine($"{fieldName} = parent.{fieldName};"); + } + + // Create scopes for imported modules (so their scoped services are properly isolated) + foreach(var module in container.ImportedModules) + { + var fieldName = GetModuleFieldName(module.Name); + writer.WriteLine($"{fieldName} = ({module.Name})parent.{fieldName}.CreateScope().ServiceProvider;"); + } + + // Initialize eager scoped services + if(groups.ScopedEntries.Any(static entry => entry is EagerContainerEntry or AsyncContainerEntry { IsEager: true })) + { + writer.WriteLine(); + writer.WriteLine("// Initialize eager scoped services"); + foreach(var entry in groups.ScopedEntries) + { + entry.WriteEagerInit(writer); + } + } + + // Initialize Lazy/Func wrapper fields (each scope gets its own wrappers) + WriteWrapperInitializations(writer, groups.WrapperEntries); + + writer.Indentation--; + writer.WriteLine("}"); + writer.WriteLine(); + + writer.WriteLine("#endregion"); + writer.WriteLine(); + } + + /// + /// Writes the constructor body for building the service resolver dictionary. + /// + private static void WriteConstructorBody( + SourceWriter writer, + ContainerModel container, + ContainerRegistrationGroups groups, + bool hasParameter, + bool effectiveUseSwitchStatement) + { + // Initialize imported modules + foreach(var module in container.ImportedModules) + { + var fieldName = GetModuleFieldName(module.Name); + if(container.IntegrateServiceProvider && hasParameter) + { + writer.WriteLine($"{fieldName} = new {module.Name}(fallbackProvider);"); + } + else + { + writer.WriteLine($"{fieldName} = new {module.Name}();"); + } + } + + // Initialize eager singleton services + if(groups.SingletonEntries.Any(static entry => entry is EagerContainerEntry or AsyncContainerEntry { IsEager: true })) + { + writer.WriteLine(); + writer.WriteLine("// Initialize eager singletons"); + foreach(var entry in groups.SingletonEntries) + { + entry.WriteEagerInit(writer); + } + } + + // Initialize Lazy/Func wrapper fields + WriteWrapperInitializations(writer, groups.WrapperEntries); + } + + private static void WriteWrapperInitializations( + SourceWriter writer, + ImmutableEquatableArray wrapperEntries) + { + var lazyEntries = wrapperEntries.OfType().ToList(); + var funcEntries = wrapperEntries.OfType().ToList(); + + if(lazyEntries.Count == 0 && funcEntries.Count == 0) + return; + + if(lazyEntries.Count > 0) + { + writer.WriteLine(); + writer.WriteLine("// Initialize Lazy wrapper fields"); + foreach(var entry in lazyEntries) + { + entry.WriteInit(writer); + } + } + + if(funcEntries.Count > 0) + { + writer.WriteLine(); + writer.WriteLine("// Initialize Func wrapper fields"); + foreach(var entry in funcEntries) + { + entry.WriteInit(writer); + } + } + } + + private static bool TryGetServiceFieldName(ContainerEntry entry, out string fieldName) + { + switch(entry) + { + case EagerContainerEntry eager: + fieldName = eager.FieldName; + return true; + case LazyThreadSafeContainerEntry lazy: + fieldName = lazy.FieldName; + return true; + case AsyncContainerEntry asyncEntry: + fieldName = asyncEntry.FieldName; + return true; + default: + fieldName = string.Empty; + return false; + } + } + + /// + /// Gets the field name for an imported module. + /// + private static string GetModuleFieldName(string moduleName) + { + var baseName = GetSafeIdentifier(moduleName); + return $"_{char.ToLowerInvariant(baseName[0])}{baseName[1..]}"; + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/GenerateContainerOutput.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/GenerateContainerOutput.cs new file mode 100644 index 0000000..d003267 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/GenerateContainerOutput.cs @@ -0,0 +1,79 @@ +using static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Generates the container source code output. + /// + private static void GenerateContainerOutput( + in SourceProductionContext ctx, + ContainerWithGroups containerWithGroups, + string assemblyName, + MsBuildProperties msbuildProps, + bool hasDIPackage) + { + if((msbuildProps.Features & IocFeatures.Container) == 0) + return; + + var source = GenerateContainerSource(containerWithGroups, assemblyName, msbuildProps, hasDIPackage); + var fileName = $"{containerWithGroups.Container.ClassName}.Container.g.cs"; + ctx.AddSource(fileName, source); + } + + /// + /// Generates the container source code. + /// + private static string GenerateContainerSource( + ContainerWithGroups containerWithGroups, + string assemblyName, + MsBuildProperties msbuildProps, + bool hasDIPackage) + { + var writer = new SourceWriter(); + var container = containerWithGroups.Container; + var groups = containerWithGroups.Groups; + + // Determine if IServiceProviderFactory should be generated + // Only generate if IntegrateServiceProvider is true AND the DI package is referenced + var canGenerateServiceProviderFactory = container.IntegrateServiceProvider && hasDIPackage; + + // Effective UseSwitchStatement: when there are imported modules, always use FrozenDictionary + // because combining services from multiple sources requires dictionary-based lookup + var effectiveUseSwitchStatement = container.UseSwitchStatement && container.ImportedModules.Length == 0; + + WriteContainerHeader(writer); + + // Write [assembly: MetadataUpdateHandler] attribute for hot reload cache invalidation + if(container.ImplementComponentActivator || container.ImplementComponentPropertyActivator) + { + writer.WriteLine($"[assembly: global::System.Reflection.Metadata.MetadataUpdateHandler(typeof({container.ContainerTypeName}.__HotReloadHandler))]"); + writer.WriteLine(); + } + + WriteContainerNamespaceAndClass(writer, container, groups, canGenerateServiceProviderFactory, effectiveUseSwitchStatement); + + return writer.ToString(); + } + + /// + /// Writes the auto-generated header and using directives. + /// + private static void WriteContainerHeader(SourceWriter writer) + { + writer.WriteLine(AutoGeneratedHeader); + writer.WriteLine(NullableEnable); + writer.WriteLine("#pragma warning disable SGIOCEXP001"); + writer.WriteLine(); + writer.WriteLine("using System;"); + writer.WriteLine("using System.Collections.Frozen;"); + writer.WriteLine("using System.Collections.Generic;"); + writer.WriteLine("using System.Linq;"); + writer.WriteLine("using System.Threading;"); + writer.WriteLine("using System.Threading.Tasks;"); + writer.WriteLine("using Microsoft.Extensions.DependencyInjection;"); + writer.WriteLine("using SourceGen.Ioc;"); + writer.WriteLine(); + } +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ResolvedDependency.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ResolvedDependency.cs new file mode 100644 index 0000000..0f51eb5 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Container/ResolvedDependency.cs @@ -0,0 +1,9 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + private abstract record class ResolvedDependency + { + public abstract string FormatExpression(bool isOptional); + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/FuncRegistrationHelper.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/FuncRegistrationHelper.cs similarity index 70% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/FuncRegistrationHelper.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/FuncRegistrationHelper.cs index 2518b55..5245377 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/FuncRegistrationHelper.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/FuncRegistrationHelper.cs @@ -15,15 +15,35 @@ partial class IocSourceGenerator /// The implementation injection members. /// The Func input type parameters. /// The tags inherited from the source registration. - private readonly record struct FuncRegistrationEntry( - string FuncServiceTypeName, - string InnerServiceTypeName, - string ImplementationTypeName, - ServiceLifetime Lifetime, - ImmutableEquatableArray? ImplementationTypeConstructorParams, - ImmutableEquatableArray ImplementationTypeInjectionMembers, - ImmutableEquatableArray InputTypes, - ImmutableEquatableArray Tags); + private readonly partial record struct FuncRegistrationEntry + { + public void WriteRegistration(SourceWriter writer) + { + var lifetime = Lifetime.Name; + + if(InputTypes.Length == 0) + { + var wrapperTypeName = $"global::System.Func<{InnerServiceTypeName}>"; + var resolveCall = $"sp.{GetRequiredService}<{ImplementationTypeName}>()"; + writer.WriteLine( + $"services.Add{lifetime}<{wrapperTypeName}>(({IServiceProviderGlobalTypeName} sp) => " + + $"new {wrapperTypeName}(() => {resolveCall}));"); + return; + } + + writer.WriteLine($"services.Add{lifetime}<{FuncServiceTypeName}>(({IServiceProviderGlobalTypeName} sp) =>"); + writer.Indentation++; + writer.WriteLine($"new {FuncServiceTypeName}(({BuildFuncLambdaParameters(InputTypes)}) =>"); + writer.WriteLine("{"); + writer.Indentation++; + + WriteFuncFactoryBody(writer, this); + + writer.Indentation--; + writer.WriteLine("}));"); + writer.Indentation--; + } + } /// /// Collects Func standalone registration entries needed by consumer dependencies. @@ -147,9 +167,9 @@ private static void ScanTypeForFuncNeeds( /// private static void WriteFuncRegistrations( SourceWriter writer, - List? entries) + ImmutableEquatableArray entries) { - if(entries is null or { Count: 0 }) + if(entries.Length == 0) return; writer.WriteLine(); @@ -157,29 +177,7 @@ private static void WriteFuncRegistrations( foreach(var entry in entries) { - var lifetime = entry.Lifetime.Name; - - if(entry.InputTypes.Length == 0) - { - var wrapperTypeName = $"global::System.Func<{entry.InnerServiceTypeName}>"; - var resolveCall = $"sp.{GetRequiredService}<{entry.ImplementationTypeName}>()"; - writer.WriteLine( - $"services.Add{lifetime}<{wrapperTypeName}>(({IServiceProviderGlobalTypeName} sp) => " + - $"new {wrapperTypeName}(() => {resolveCall}));"); - continue; - } - - writer.WriteLine($"services.Add{lifetime}<{entry.FuncServiceTypeName}>(({IServiceProviderGlobalTypeName} sp) =>"); - writer.Indentation++; - writer.WriteLine($"new {entry.FuncServiceTypeName}(({BuildFuncLambdaParameters(entry.InputTypes)}) =>"); - writer.WriteLine("{"); - writer.Indentation++; - - WriteFuncFactoryBody(writer, entry); - - writer.Indentation--; - writer.WriteLine("}));"); - writer.Indentation--; + entry.WriteRegistration(writer); } } @@ -200,20 +198,20 @@ private static void WriteFuncFactoryBody(SourceWriter writer, FuncRegistrationEn inputArgUsed[i] = false; } - var constructorParamEntries = new List<(string Name, string? Value, bool NeedsConditional)>(constructorParams.Length); + var constructorParamEntries = new List<(string Name, string? Value)>(constructorParams.Length); var resolvedParamIndex = 0; foreach(var param in constructorParams) { var matchedArg = TryConsumeMatchingFuncInputArg(param.Type.Name, inputArgNames, inputArgTypeNames, inputArgUsed); if(matchedArg is not null) { - constructorParamEntries.Add((param.Name, matchedArg, false)); + constructorParamEntries.Add((param.Name, matchedArg)); continue; } var paramVar = $"p{resolvedParamIndex}"; var resolvedVar = ResolveParamAndEmitVar(writer, param, paramVar, isKeyedRegistration: false, registrationKey: null); - constructorParamEntries.Add((param.Name, resolvedVar, false)); + constructorParamEntries.Add((param.Name, resolvedVar)); resolvedParamIndex++; } @@ -242,7 +240,7 @@ private static void WriteFuncFactoryBody(SourceWriter writer, FuncRegistrationEn propertyFieldIndex++; } - var constructorArgs = BuildArgumentListFromEntries([.. constructorParamEntries]); + var constructorArgs = BuildArgumentListFromEntries(constructorParamEntries); var initializerPart = propertyInits.Count > 0 ? $" {{ {string.Join(", ", propertyInits)} }}" : string.Empty; var constructorInvocation = BuildConstructorInvocation(entry.ImplementationTypeName, constructorArgs, initializerPart); writer.WriteLine($"var s0 = {constructorInvocation};"); @@ -254,13 +252,13 @@ private static void WriteFuncFactoryBody(SourceWriter writer, FuncRegistrationEn continue; var methodParams = method.Parameters ?? []; - var methodEntries = new List<(string Name, string? Value, bool NeedsConditional)>(methodParams.Length); + var methodEntries = new List<(string Name, string? Value)>(methodParams.Length); foreach(var param in methodParams) { var matchedArg = TryConsumeMatchingFuncInputArg(param.Type.Name, inputArgNames, inputArgTypeNames, inputArgUsed); if(matchedArg is not null) { - methodEntries.Add((param.Name, matchedArg, false)); + methodEntries.Add((param.Name, matchedArg)); continue; } @@ -282,11 +280,11 @@ private static void WriteFuncFactoryBody(SourceWriter writer, FuncRegistrationEn resolvedVar = ResolveParamAndEmitVar(writer, param, paramVar, isKeyedRegistration: false, registrationKey: null); } - methodEntries.Add((param.Name, resolvedVar, false)); + methodEntries.Add((param.Name, resolvedVar)); methodParamIndex++; } - var methodArgs = BuildArgumentListFromEntries([.. methodEntries]); + var methodArgs = BuildArgumentListFromEntries(methodEntries); writer.WriteLine($"s0.{method.Name}({methodArgs});"); } @@ -334,65 +332,6 @@ private static string BuildFuncLambdaParameters(ImmutableEquatableArray - /// Collects Func resolver entries for container code generation. - /// Only single-parameter Func<T> wrappers are tracked as field-backed wrapper resolvers. - /// - private static ImmutableEquatableArray CollectContainerFuncEntries( - ImmutableEquatableArray singletons, - ImmutableEquatableArray scoped, - ImmutableEquatableArray transients, - ImmutableEquatableDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> byServiceTypeAndKey) - { - var neededTypes = new HashSet(StringComparer.Ordinal); - - foreach(var lifetime in new[] { singletons, scoped, transients }) - { - foreach(var cached in lifetime) - { - var reg = cached.Registration; - ScanParamsForContainerFuncNeeds(reg.ImplementationType.ConstructorParameters, neededTypes); - ScanInjectionMembersForContainerFuncNeeds(reg.InjectionMembers, neededTypes); - } - } - - if(neededTypes.Count == 0) - return []; - - var entries = new List(); - var addedKeys = new HashSet(StringComparer.Ordinal); - - foreach(var kvp in byServiceTypeAndKey) - { - var serviceType = kvp.Key.ServiceType; - if(!neededTypes.Contains(serviceType)) - continue; - - foreach(var cached in kvp.Value) - { - var reg = cached.Registration; - if(reg.IsOpenGeneric) - continue; - - // Async-init services cannot be resolved synchronously β€” exclude from Func entries - if(cached.IsAsyncInit) - continue; - - var entryKey = $"{serviceType}|{reg.ImplementationType.Name}|{reg.Key}"; - if(!addedKeys.Add(entryKey)) - continue; - - var safeInnerType = GetSafeIdentifier(serviceType); - var safeImplType = GetSafeIdentifier(reg.ImplementationType.Name); - var fieldName = $"_func_{safeInnerType}_{safeImplType}"; - - entries.Add(new ContainerFuncEntry(serviceType, cached.ResolverMethodName, fieldName)); - } - } - - return entries.ToImmutableEquatableArray(); - } - /// /// Scans constructor parameters for single-parameter Func dependencies. /// @@ -457,113 +396,101 @@ private static void ScanTypeForContainerFuncNeeds(TypeData type, HashSet } /// - /// Writes Func field declarations and array resolvers for container generation. + /// Gets the array resolver method name for a Func wrapper type. /// - private static void WriteContainerFuncFields( - SourceWriter writer, - ImmutableEquatableArray entries) + private static string GetFuncArrayResolverMethodName(string innerServiceTypeName) { - if(entries.Length == 0) - return; + var safeInnerType = GetSafeIdentifier(innerServiceTypeName); + return $"GetAllFunc_{safeInnerType}_Array"; + } - writer.WriteLine("// Func wrapper fields"); - writer.WriteLine(); + private static ImmutableEquatableArray CreateFuncWrapperContainerEntries( + ImmutableEquatableArray singletons, + ImmutableEquatableArray scoped, + ImmutableEquatableArray transients, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup) + { + var neededTypes = new HashSet(StringComparer.Ordinal); - foreach(var entry in entries) + foreach(var lifetime in new[] { singletons, scoped, transients }) { - var wrapperTypeName = $"global::System.Func<{entry.InnerServiceTypeName}>"; - writer.WriteLine($"private readonly {wrapperTypeName} {entry.FieldName};"); + foreach(var cached in lifetime) + { + var reg = cached.Registration; + ScanParamsForContainerFuncNeeds(reg.ImplementationType.ConstructorParameters, neededTypes); + ScanInjectionMembersForContainerFuncNeeds(reg.InjectionMembers, neededTypes); + } } - writer.WriteLine(); - var grouped = entries - .GroupBy(static e => e.InnerServiceTypeName) - .ToList(); + if(neededTypes.Count == 0) + return []; - foreach(var group in grouped) - { - var innerServiceTypeName = group.Key; - var wrapperTypeName = $"global::System.Func<{innerServiceTypeName}>"; - var arrayMethodName = GetFuncArrayResolverMethodName(innerServiceTypeName); + var entries = new List<(string InnerServiceTypeName, string InnerImplTypeName, string FieldName, string ResolverMethodName, string? Key)>(); + var addedKeys = new HashSet(StringComparer.Ordinal); - writer.WriteLine($"private {wrapperTypeName}[] {arrayMethodName}() =>"); - writer.Indentation++; - writer.WriteLine("["); - writer.Indentation++; + foreach(var kvp in serviceLookup) + { + var serviceType = kvp.Key.ServiceType; + if(!neededTypes.Contains(serviceType)) + continue; - foreach(var entry in group) + foreach(var cached in kvp.Value) { - writer.WriteLine($"{entry.FieldName},"); - } + var reg = cached.Registration; + if(reg.IsOpenGeneric || cached.IsAsyncInit) + continue; - writer.Indentation--; - writer.WriteLine("];"); - writer.Indentation--; - writer.WriteLine(); - } - } + var entryKey = $"{serviceType}|{reg.ImplementationType.Name}|{reg.Key}"; + if(!addedKeys.Add(entryKey)) + continue; - /// - /// Writes field initialization statements for Func wrapper fields. - /// - private static void WriteContainerFuncFieldInitializations( - SourceWriter writer, - ImmutableEquatableArray entries) - { - if(entries.Length == 0) - return; + var safeInnerType = GetSafeIdentifier(serviceType); + var safeImplType = GetSafeIdentifier(reg.ImplementationType.Name); + var fieldName = $"_func_{safeInnerType}_{safeImplType}"; - writer.WriteLine(); - writer.WriteLine("// Initialize Func wrapper fields"); - foreach(var entry in entries) - { - var wrapperTypeName = $"global::System.Func<{entry.InnerServiceTypeName}>"; - writer.WriteLine($"{entry.FieldName} = new {wrapperTypeName}(() => {entry.ResolverMethodName}());"); + entries.Add((serviceType, reg.ImplementationType.Name, fieldName, cached.ResolverMethodName, reg.Key)); + } } - } - /// - /// Writes _localResolvers entries for Func wrapper services. - /// - private static void WriteContainerFuncLocalResolverEntries( - SourceWriter writer, - string containerTypeName, - ImmutableEquatableArray entries) - { - if(entries.Length == 0) - return; + if(entries.Count == 0) + return []; - writer.WriteLine(); - writer.WriteLine("// Func wrapper resolvers"); + var collectionFieldsByServiceType = entries + .GroupBy(static e => e.InnerServiceTypeName) + .ToDictionary( + static g => g.Key, + static g => g.Select(static e => e.FieldName).ToImmutableEquatableArray(), + StringComparer.Ordinal); - var grouped = entries + var collectionEmitterFieldByServiceType = entries .GroupBy(static e => e.InnerServiceTypeName) - .ToList(); + .ToDictionary( + static g => g.Key, + static g => g.Last().FieldName, + StringComparer.Ordinal); + + var wrapperEntries = new List(entries.Count); - foreach(var group in grouped) + foreach(var entry in entries) { - var innerServiceTypeName = group.Key; - var wrapperTypeName = $"global::System.Func<{innerServiceTypeName}>"; - - var lastEntry = group.Last(); - writer.WriteLine($"new(new ServiceIdentifier(typeof({wrapperTypeName}), {KeyedServiceAnyKey}), static c => c.{lastEntry.FieldName}),"); - - var arrayMethodName = GetFuncArrayResolverMethodName(innerServiceTypeName); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IEnumerable<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyCollection<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.ICollection<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyList<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IList<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof({wrapperTypeName}[]), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); + var emitCollectionResolver = string.Equals( + entry.FieldName, + collectionEmitterFieldByServiceType[entry.InnerServiceTypeName], + StringComparison.Ordinal); + var collectionFieldNames = emitCollectionResolver + ? collectionFieldsByServiceType[entry.InnerServiceTypeName] + : []; + + wrapperEntries.Add(new FuncWrapperContainerEntry( + entry.InnerServiceTypeName, + entry.InnerImplTypeName, + entry.FieldName, + entry.ResolverMethodName, + entry.Key, + emitCollectionResolver, + collectionFieldNames)); } - } - /// - /// Gets the array resolver method name for a Func wrapper type. - /// - private static string GetFuncArrayResolverMethodName(string innerServiceTypeName) - { - var safeInnerType = GetSafeIdentifier(innerServiceTypeName); - return $"GetAllFunc_{safeInnerType}_Array"; + return wrapperEntries.ToImmutableEquatableArray(); } } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/GenerateRegisterOutput.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/GenerateRegisterOutput.cs new file mode 100644 index 0000000..258730a --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/GenerateRegisterOutput.cs @@ -0,0 +1,150 @@ +ο»Ώusing static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + private static void GenerateRegisterOutput( + in SourceProductionContext ctx, + RegisterOutputModel model) + { + var source = GenerateExtensionMethodSource(model); + ctx.AddSource($"{model.AssemblyName}.ServiceRegistration.g.cs", source); + } + + private static string GenerateExtensionMethodSource( + RegisterOutputModel model) + { + // Build the final output + var mainWriter = new SourceWriter(); + + mainWriter.WriteLine(AutoGeneratedHeader); + mainWriter.WriteLine(NullableEnable); + mainWriter.WriteLine(); + mainWriter.WriteLine("using Microsoft.Extensions.DependencyInjection;"); + mainWriter.WriteLine("using System.Collections.Generic;"); + mainWriter.WriteLine("using System.Linq;"); + mainWriter.WriteLine(); + + mainWriter.WriteLine($"namespace {GetSafeNamespace(model.RootNamespace)}"); + mainWriter.WriteLine("{"); + mainWriter.Indentation++; + + mainWriter.WriteLine("/// "); + mainWriter.WriteLine($"/// Extension methods for registering services from {model.AssemblyName}."); + mainWriter.WriteLine("/// "); + mainWriter.WriteLine($"public static class {model.MethodBaseName}ServiceCollectionExtensions"); + mainWriter.WriteLine("{"); + mainWriter.Indentation++; + + // Generate single method with tags parameter + mainWriter.WriteLine("/// "); + mainWriter.WriteLine("/// Registers services. Services with tags are only registered when matching tags are passed."); + mainWriter.WriteLine("/// "); + mainWriter.WriteLine("/// The service collection."); + mainWriter.WriteLine("/// Optional tags to filter which services to register."); + mainWriter.WriteLine($"public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Add{model.MethodBaseName}(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags)"); + mainWriter.WriteLine("{"); + mainWriter.Indentation++; + + var context = new RegisterWriteContext(model.AsyncInitServiceTypes); + + bool isFirstGroup = true; + foreach(var group in model.TagGroups) + { + if(!isFirstGroup) + { + mainWriter.WriteLine(); + } + isFirstGroup = false; + + if(group.Tags.Length == 0) + { + // No tags - only register when no tags passed (mutually exclusive model) + WriteNoTagConditionalBlock(mainWriter, group.Registrations, group.LazyEntries, group.FuncEntries, group.KvpEntries, context); + } + else + { + // Has tags - only register when tags match + WriteConditionalTagBlock(mainWriter, group.Tags, group.Registrations, group.LazyEntries, group.FuncEntries, group.KvpEntries, context); + } + } + + mainWriter.WriteLine(); + mainWriter.WriteLine("return services;"); + + mainWriter.Indentation--; + mainWriter.WriteLine("}"); + + mainWriter.Indentation--; + mainWriter.WriteLine("}"); + + mainWriter.Indentation--; + mainWriter.WriteLine("}"); + + return mainWriter.ToString(); + } + + /// + /// Writes a conditional block for tag-based registration. + /// Services with tags are only registered when the passed tags match. + /// + private static void WriteConditionalTagBlock( + SourceWriter writer, + ImmutableEquatableArray tags, + ImmutableEquatableArray registrations, + ImmutableEquatableArray lazyEntries, + ImmutableEquatableArray funcEntries, + ImmutableEquatableArray kvpEntries, + RegisterWriteContext context) + { + // Build the condition - only register when tags match + var tagConditions = tags.Select(static tag => $"tags.Contains({SymbolDisplay.FormatLiteral(tag, quote: true)})"); + var condition = string.Join(" || ", tagConditions); + + writer.WriteLine($"if ({condition})"); + writer.WriteLine("{"); + writer.Indentation++; + + foreach(var registration in registrations) + { + registration.WriteRegistration(writer, context); + } + + WriteLazyRegistrations(writer, lazyEntries); + WriteFuncRegistrations(writer, funcEntries); + WriteKvpRegistrations(writer, kvpEntries); + + writer.Indentation--; + writer.WriteLine("}"); + } + + /// + /// Writes a conditional block for services without tags. + /// Services without tags are only registered when no tags are passed (mutually exclusive model). + /// + private static void WriteNoTagConditionalBlock( + SourceWriter writer, + ImmutableEquatableArray registrations, + ImmutableEquatableArray lazyEntries, + ImmutableEquatableArray funcEntries, + ImmutableEquatableArray kvpEntries, + RegisterWriteContext context) + { + writer.WriteLine("if (!tags.Any())"); + writer.WriteLine("{"); + writer.Indentation++; + + foreach(var registration in registrations) + { + registration.WriteRegistration(writer, context); + } + + WriteLazyRegistrations(writer, lazyEntries); + WriteFuncRegistrations(writer, funcEntries); + WriteKvpRegistrations(writer, kvpEntries); + + writer.Indentation--; + writer.WriteLine("}"); + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/KvpRegistrationHelper.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/KvpRegistrationHelper.cs similarity index 53% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/KvpRegistrationHelper.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/KvpRegistrationHelper.cs index 2c0fe8c..98c78f4 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/KvpRegistrationHelper.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/KvpRegistrationHelper.cs @@ -14,12 +14,23 @@ partial class IocSourceGenerator /// The key literal expression (e.g., "Key1"). /// The service lifetime matching the keyed value service. /// The tags inherited from the source registration. - private readonly record struct KvpRegistrationEntry( - string KeyTypeName, - string ValueTypeName, - string KeyExpr, - ServiceLifetime Lifetime, - ImmutableEquatableArray Tags); + private readonly partial record struct KvpRegistrationEntry + { + public void WriteRegistration(SourceWriter writer) + { + var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{KeyTypeName}, {ValueTypeName}>"; + var lifetime = Lifetime.Name; + var resolveCall = $"sp.{GetRequiredKeyedService}<{ValueTypeName}>({KeyExpr})"; + + // KeyValuePair is a struct, so we cannot use AddSingleton (class constraint). + // Use ServiceDescriptor directly with a factory that boxes the struct. + writer.WriteLine( + $"services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(" + + $"typeof({kvpTypeName}), " + + $"({IServiceProviderGlobalTypeName} sp) => (object)new {kvpTypeName}({KeyExpr}, {resolveCall}), " + + $"global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.{lifetime}));"); + } + } /// /// Collects KeyValuePair registration entries needed by consumer dependencies. @@ -142,9 +153,9 @@ private static void ScanTypeForKvpNeeds( /// private static void WriteKvpRegistrations( SourceWriter writer, - List? entries) + ImmutableEquatableArray entries) { - if(entries is null or { Count: 0 }) + if(entries.Length == 0) return; writer.WriteLine(); @@ -152,31 +163,55 @@ private static void WriteKvpRegistrations( foreach(var entry in entries) { - var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{entry.KeyTypeName}, {entry.ValueTypeName}>"; - var lifetime = entry.Lifetime.Name; - var resolveCall = $"sp.{GetRequiredKeyedService}<{entry.ValueTypeName}>({entry.KeyExpr})"; - - // KeyValuePair is a struct, so we cannot use AddSingleton (class constraint). - // Use ServiceDescriptor directly with a factory that boxes the struct. - writer.WriteLine( - $"services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(" + - $"typeof({kvpTypeName}), " + - $"({IServiceProviderGlobalTypeName} sp) => (object)new {kvpTypeName}({entry.KeyExpr}, {resolveCall}), " + - $"global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.{lifetime}));"); + entry.WriteRegistration(writer); } } /// - /// Collects KeyValuePair resolver entries for container code generation. - /// Scans all container registrations for KVP dependencies and finds matching keyed services. + /// Gets the array resolver method name for a KVP type pair. + /// + private static string GetKvpArrayResolverMethodName(string keyTypeName, string valueTypeName) + { + var safeKeyType = GetSafeIdentifier(keyTypeName); + var safeValueType = GetSafeIdentifier(valueTypeName); + return $"GetAllKvp_{safeKeyType}_{safeValueType}_Array"; + } + + /// + /// Gets the dictionary resolver method name for a KVP type pair. /// - private static ImmutableEquatableArray CollectContainerKvpEntries( - ImmutableEquatableArray singletons, - ImmutableEquatableArray scoped, - ImmutableEquatableArray transients, - ImmutableEquatableDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> byServiceTypeAndKey) + private static string GetKvpDictionaryResolverMethodName(string keyTypeName, string valueTypeName) + { + var safeKeyType = GetSafeIdentifier(keyTypeName); + var safeValueType = GetSafeIdentifier(valueTypeName); + return $"GetAllKvp_{safeKeyType}_{safeValueType}_Dictionary"; + } + + /// + /// Checks if a registration's key value type is compatible with the requested KVP key type. + /// + /// The key type name requested by the consumer (e.g., "string", "object"). + /// The actual key value type of the registration, or null if unknown. + /// True if the key types are compatible. + private static bool IsKeyTypeCompatible(string requestedKeyTypeName, TypeData? registrationKeyValueType) + { + // "object" key type accepts all key value types + if(string.Equals(requestedKeyTypeName, "object", StringComparison.Ordinal)) + return true; + + // If the registration's key value type is unknown (null), treat as object β€” only compatible with object + if(registrationKeyValueType is null) + return false; + + return string.Equals(registrationKeyValueType.Name, requestedKeyTypeName, StringComparison.Ordinal); + } + + private static ImmutableEquatableArray CreateKvpWrapperContainerEntries( + ImmutableEquatableArray singletons, + ImmutableEquatableArray scoped, + ImmutableEquatableArray transients, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup) { - // Step 1: Scan all registrations for KVP needs var neededPairs = new HashSet<(string KeyTypeName, string ValueTypeName)>(); foreach(var lifetime in new[] { singletons, scoped, transients }) @@ -192,11 +227,10 @@ private static ImmutableEquatableArray CollectContainerKvpEnt if(neededPairs.Count == 0) return []; - // Step 2: Find keyed registrations matching those value types - var entries = new List(); + var wrapperEntries = new List(); var addedKeys = new HashSet(StringComparer.Ordinal); - foreach(var kvp in byServiceTypeAndKey) + foreach(var kvp in serviceLookup) { var key = kvp.Key.Key; if(key is null) @@ -209,11 +243,7 @@ private static ImmutableEquatableArray CollectContainerKvpEnt if(!string.Equals(serviceType, valueTypeName, StringComparison.Ordinal)) continue; - var cached = kvp.Value[^1]; // Last wins - - // Filter by key value type compatibility: - // - "object" key type accepts all key value types - // - Otherwise, the registration's key value type must match exactly + var cached = kvp.Value[^1]; if(!IsKeyTypeCompatible(keyTypeName, cached.Registration.KeyValueType)) continue; @@ -226,7 +256,7 @@ private static ImmutableEquatableArray CollectContainerKvpEnt var safeValueType = GetSafeIdentifier(valueTypeName); var kvpResolverName = $"GetKvp_{safeKeyType}_{safeValueType}_{safeKey}"; - entries.Add(new ContainerKvpEntry( + wrapperEntries.Add(new KvpWrapperContainerEntry( keyTypeName, valueTypeName, key, @@ -235,186 +265,6 @@ private static ImmutableEquatableArray CollectContainerKvpEnt } } - return entries.ToImmutableEquatableArray(); - } - - /// - /// Writes KVP resolver methods and the array resolver for container generation. - /// Individual methods return a single KeyValuePair<K, V> entry. - /// The array resolver collects all KVP entries for GetServices<KVP<K, V>>. - /// - private static void WriteContainerKvpResolverMethods( - SourceWriter writer, - ImmutableEquatableArray entries) - { - if(entries.Length == 0) - return; - - writer.WriteLine("// KeyValuePair resolver methods"); - writer.WriteLine(); - - // Write individual KVP resolver methods - foreach(var entry in entries) - { - var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{entry.KeyTypeName}, {entry.ValueTypeName}>"; - writer.WriteLine($"private {kvpTypeName} {entry.KvpResolverMethodName}() => new {kvpTypeName}({entry.KeyExpr}, {entry.ResolverMethodName}());"); - writer.WriteLine(); - } - - // Write array resolver methods grouped by (KeyTypeName, ValueTypeName) - var grouped = entries - .GroupBy(static e => (e.KeyTypeName, e.ValueTypeName)) - .ToList(); - - foreach(var group in grouped) - { - var (keyTypeName, valueTypeName) = group.Key; - var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{keyTypeName}, {valueTypeName}>"; - var arrayMethodName = GetKvpArrayResolverMethodName(keyTypeName, valueTypeName); - - writer.WriteLine($"private {kvpTypeName}[] {arrayMethodName}() =>"); - writer.Indentation++; - writer.WriteLine("["); - writer.Indentation++; - - foreach(var entry in group) - { - writer.WriteLine($"{entry.KvpResolverMethodName}(),"); - } - - writer.Indentation--; - writer.WriteLine("];"); - writer.Indentation--; - writer.WriteLine(); - } - - // Write dictionary resolver methods grouped by (KeyTypeName, ValueTypeName) - foreach(var group in grouped) - { - var (keyTypeName, valueTypeName) = group.Key; - var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{keyTypeName}, {valueTypeName}>"; - var dictionaryMethodName = GetKvpDictionaryResolverMethodName(keyTypeName, valueTypeName); - - writer.WriteLine($"private global::System.Collections.Generic.Dictionary<{keyTypeName}, {valueTypeName}> {dictionaryMethodName}() =>"); - writer.Indentation++; - writer.WriteLine($"new global::System.Collections.Generic.Dictionary<{keyTypeName}, {valueTypeName}>()"); - writer.WriteLine("{"); - writer.Indentation++; - - foreach(var entry in group) - { - writer.WriteLine($"[{entry.KeyExpr}] = {entry.ResolverMethodName}(),"); - } - - writer.Indentation--; - writer.WriteLine("};"); - writer.Indentation--; - writer.WriteLine(); - } - } - - /// - /// Writes _localResolvers entries for KVP services so that - /// GetServices<KeyValuePair<K, V>> can collect them. - /// - private static void WriteContainerKvpLocalResolverEntries( - SourceWriter writer, - string containerTypeName, - ImmutableEquatableArray entries) - { - if(entries.Length == 0) - return; - - writer.WriteLine(); - writer.WriteLine("// KeyValuePair resolvers"); - - // Write IEnumerable> entries grouped by (KeyTypeName, ValueTypeName) - var grouped = entries - .GroupBy(static e => (e.KeyTypeName, e.ValueTypeName)) - .ToList(); - - foreach(var group in grouped) - { - var (keyTypeName, valueTypeName) = group.Key; - var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{keyTypeName}, {valueTypeName}>"; - var arrayMethodName = GetKvpArrayResolverMethodName(keyTypeName, valueTypeName); - var dictionaryMethodName = GetKvpDictionaryResolverMethodName(keyTypeName, valueTypeName); - - // IEnumerable, IReadOnlyCollection, ICollection β†’ Dictionary - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IEnumerable<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyCollection<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.ICollection<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); - - // IReadOnlyList, IList, T[] β†’ Array - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyList<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IList<{kvpTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof({kvpTypeName}[]), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - - // IReadOnlyDictionary, IDictionary, Dictionary β†’ Dictionary - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyDictionary<{keyTypeName}, {valueTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IDictionary<{keyTypeName}, {valueTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.Dictionary<{keyTypeName}, {valueTypeName}>), {KeyedServiceAnyKey}), static c => c.{dictionaryMethodName}()),"); - } - } - - /// - /// Gets the array resolver method name for a KVP type pair. - /// - private static string GetKvpArrayResolverMethodName(string keyTypeName, string valueTypeName) - { - var safeKeyType = GetSafeIdentifier(keyTypeName); - var safeValueType = GetSafeIdentifier(valueTypeName); - return $"GetAllKvp_{safeKeyType}_{safeValueType}_Array"; - } - - /// - /// Gets the dictionary resolver method name for a KVP type pair. - /// - private static string GetKvpDictionaryResolverMethodName(string keyTypeName, string valueTypeName) - { - var safeKeyType = GetSafeIdentifier(keyTypeName); - var safeValueType = GetSafeIdentifier(valueTypeName); - return $"GetAllKvp_{safeKeyType}_{safeValueType}_Dictionary"; - } - - /// - /// Checks if any keyed registrations exist for the given KVP key/value type pair. - /// Used to determine whether a KVP resolver method will be generated. - /// - private static bool HasKvpRegistrations(string keyTypeName, string valueTypeName, ContainerRegistrationGroups groups) - { - foreach(var kvp in groups.ByServiceTypeAndKey) - { - if(kvp.Key.Key is null) - continue; - - if(!string.Equals(kvp.Key.ServiceType, valueTypeName, StringComparison.Ordinal)) - continue; - - var cached = kvp.Value[^1]; - if(IsKeyTypeCompatible(keyTypeName, cached.Registration.KeyValueType)) - return true; - } - - return false; - } - - /// - /// Checks if a registration's key value type is compatible with the requested KVP key type. - /// - /// The key type name requested by the consumer (e.g., "string", "object"). - /// The actual key value type of the registration, or null if unknown. - /// True if the key types are compatible. - private static bool IsKeyTypeCompatible(string requestedKeyTypeName, TypeData? registrationKeyValueType) - { - // "object" key type accepts all key value types - if(string.Equals(requestedKeyTypeName, "object", StringComparison.Ordinal)) - return true; - - // If the registration's key value type is unknown (null), treat as object β€” only compatible with object - if(registrationKeyValueType is null) - return false; - - return string.Equals(registrationKeyValueType.Name, requestedKeyTypeName, StringComparison.Ordinal); + return wrapperEntries.ToImmutableEquatableArray(); } } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/LazyRegistrationHelper.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/LazyRegistrationHelper.cs similarity index 55% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/LazyRegistrationHelper.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/LazyRegistrationHelper.cs index 78204d6..91233f7 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/LazyRegistrationHelper.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/LazyRegistrationHelper.cs @@ -11,11 +11,18 @@ partial class IocSourceGenerator /// The fully-qualified implementation type name. /// The service lifetime matching the inner service. /// The tags inherited from the source registration. - private readonly record struct LazyRegistrationEntry( - string InnerServiceTypeName, - string ImplementationTypeName, - ServiceLifetime Lifetime, - ImmutableEquatableArray Tags); + private readonly partial record struct LazyRegistrationEntry + { + public void WriteRegistration(SourceWriter writer) + { + var wrapperTypeName = $"global::System.Lazy<{InnerServiceTypeName}>"; + var lifetime = Lifetime.Name; + var resolveCall = $"sp.{GetRequiredService}<{ImplementationTypeName}>()"; + writer.WriteLine( + $"services.Add{lifetime}<{wrapperTypeName}>(({IServiceProviderGlobalTypeName} sp) => " + + $"new {wrapperTypeName}(() => {resolveCall}, global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication));"); + } + } /// /// Collects Lazy standalone registration entries needed by consumer dependencies. @@ -127,9 +134,9 @@ private static void ScanTypeForLazyNeeds( /// private static void WriteLazyRegistrations( SourceWriter writer, - List? entries) + ImmutableEquatableArray entries) { - if(entries is null or { Count: 0 }) + if(entries.Length == 0) return; writer.WriteLine(); @@ -137,23 +144,24 @@ private static void WriteLazyRegistrations( foreach(var entry in entries) { - var wrapperTypeName = $"global::System.Lazy<{entry.InnerServiceTypeName}>"; - var lifetime = entry.Lifetime.Name; - var resolveCall = $"sp.{GetRequiredService}<{entry.ImplementationTypeName}>()"; - writer.WriteLine( - $"services.Add{lifetime}<{wrapperTypeName}>(({IServiceProviderGlobalTypeName} sp) => " + - $"new {wrapperTypeName}(() => {resolveCall}, global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication));"); + entry.WriteRegistration(writer); } } /// - /// Collects Lazy resolver entries for container code generation. + /// Gets the array resolver method name for a Lazy wrapper type. /// - private static ImmutableEquatableArray CollectContainerLazyEntries( - ImmutableEquatableArray singletons, - ImmutableEquatableArray scoped, - ImmutableEquatableArray transients, - ImmutableEquatableDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> byServiceTypeAndKey) + private static string GetLazyArrayResolverMethodName(string innerServiceTypeName) + { + var safeInnerType = GetSafeIdentifier(innerServiceTypeName); + return $"GetAllLazy_{safeInnerType}_Array"; + } + + private static ImmutableEquatableArray CreateLazyWrapperContainerEntries( + ImmutableEquatableArray singletons, + ImmutableEquatableArray scoped, + ImmutableEquatableArray transients, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup) { var neededTypes = new HashSet(StringComparer.Ordinal); @@ -170,10 +178,10 @@ private static ImmutableEquatableArray CollectContainerLazyE if(neededTypes.Count == 0) return []; - var entries = new List(); + var entries = new List<(string InnerServiceTypeName, string InnerImplTypeName, string FieldName, string ResolverMethodName, string? Key)>(); var addedKeys = new HashSet(StringComparer.Ordinal); - foreach(var kvp in byServiceTypeAndKey) + foreach(var kvp in serviceLookup) { var serviceType = kvp.Key.ServiceType; if(!neededTypes.Contains(serviceType)) @@ -182,11 +190,7 @@ private static ImmutableEquatableArray CollectContainerLazyE foreach(var cached in kvp.Value) { var reg = cached.Registration; - if(reg.IsOpenGeneric) - continue; - - // Async-init services cannot be resolved synchronously β€” exclude from Lazy entries - if(cached.IsAsyncInit) + if(reg.IsOpenGeneric || cached.IsAsyncInit) continue; var entryKey = $"{serviceType}|{reg.ImplementationType.Name}|{reg.Key}"; @@ -197,121 +201,49 @@ private static ImmutableEquatableArray CollectContainerLazyE var safeImplType = GetSafeIdentifier(reg.ImplementationType.Name); var fieldName = $"_lazy_{safeInnerType}_{safeImplType}"; - entries.Add(new ContainerLazyEntry(serviceType, cached.ResolverMethodName, fieldName)); + entries.Add((serviceType, reg.ImplementationType.Name, fieldName, cached.ResolverMethodName, reg.Key)); } } - return entries.ToImmutableEquatableArray(); - } - - /// - /// Writes Lazy field declarations and array resolvers for container generation. - /// - private static void WriteContainerLazyFields( - SourceWriter writer, - ImmutableEquatableArray entries) - { - if(entries.Length == 0) - return; - - writer.WriteLine("// Lazy wrapper fields"); - writer.WriteLine(); - - foreach(var entry in entries) - { - var wrapperTypeName = $"global::System.Lazy<{entry.InnerServiceTypeName}>"; - writer.WriteLine($"private readonly {wrapperTypeName} {entry.FieldName};"); - } - writer.WriteLine(); + if(entries.Count == 0) + return []; - var grouped = entries + var collectionFieldsByServiceType = entries .GroupBy(static e => e.InnerServiceTypeName) - .ToList(); - - foreach(var group in grouped) - { - var innerServiceTypeName = group.Key; - var wrapperTypeName = $"global::System.Lazy<{innerServiceTypeName}>"; - var arrayMethodName = GetLazyArrayResolverMethodName(innerServiceTypeName); + .ToDictionary( + static g => g.Key, + static g => g.Select(static e => e.FieldName).ToImmutableEquatableArray(), + StringComparer.Ordinal); - writer.WriteLine($"private {wrapperTypeName}[] {arrayMethodName}() =>"); - writer.Indentation++; - writer.WriteLine("["); - writer.Indentation++; - - foreach(var entry in group) - { - writer.WriteLine($"{entry.FieldName},"); - } + var collectionEmitterFieldByServiceType = entries + .GroupBy(static e => e.InnerServiceTypeName) + .ToDictionary( + static g => g.Key, + static g => g.Last().FieldName, + StringComparer.Ordinal); - writer.Indentation--; - writer.WriteLine("];"); - writer.Indentation--; - writer.WriteLine(); - } - } + var wrapperEntries = new List(entries.Count); - /// - /// Writes field initialization statements for Lazy wrapper fields. - /// - private static void WriteContainerLazyFieldInitializations( - SourceWriter writer, - ImmutableEquatableArray entries) - { - if(entries.Length == 0) - return; - - writer.WriteLine(); - writer.WriteLine("// Initialize Lazy wrapper fields"); foreach(var entry in entries) { - var wrapperTypeName = $"global::System.Lazy<{entry.InnerServiceTypeName}>"; - writer.WriteLine($"{entry.FieldName} = new {wrapperTypeName}(() => {entry.ResolverMethodName}(), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication);"); + var emitCollectionResolver = string.Equals( + entry.FieldName, + collectionEmitterFieldByServiceType[entry.InnerServiceTypeName], + StringComparison.Ordinal); + var collectionFieldNames = emitCollectionResolver + ? collectionFieldsByServiceType[entry.InnerServiceTypeName] + : []; + + wrapperEntries.Add(new LazyWrapperContainerEntry( + entry.InnerServiceTypeName, + entry.InnerImplTypeName, + entry.FieldName, + entry.ResolverMethodName, + entry.Key, + emitCollectionResolver, + collectionFieldNames)); } - } - /// - /// Writes _localResolvers entries for Lazy wrapper services. - /// - private static void WriteContainerLazyLocalResolverEntries( - SourceWriter writer, - string containerTypeName, - ImmutableEquatableArray entries) - { - if(entries.Length == 0) - return; - - writer.WriteLine(); - writer.WriteLine("// Lazy wrapper resolvers"); - - var grouped = entries - .GroupBy(static e => e.InnerServiceTypeName) - .ToList(); - - foreach(var group in grouped) - { - var innerServiceTypeName = group.Key; - var wrapperTypeName = $"global::System.Lazy<{innerServiceTypeName}>"; - - var lastEntry = group.Last(); - writer.WriteLine($"new(new ServiceIdentifier(typeof({wrapperTypeName}), {KeyedServiceAnyKey}), static c => c.{lastEntry.FieldName}),"); - - var arrayMethodName = GetLazyArrayResolverMethodName(innerServiceTypeName); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IEnumerable<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyCollection<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.ICollection<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyList<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IList<{wrapperTypeName}>), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof({wrapperTypeName}[]), {KeyedServiceAnyKey}), static c => c.{arrayMethodName}()),"); - } - } - - /// - /// Gets the array resolver method name for a Lazy wrapper type. - /// - private static string GetLazyArrayResolverMethodName(string innerServiceTypeName) - { - var safeInnerType = GetSafeIdentifier(innerServiceTypeName); - return $"GetAllLazy_{safeInnerType}_Array"; + return wrapperEntries.ToImmutableEquatableArray(); } } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterDecoratorWriters.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterDecoratorWriters.cs new file mode 100644 index 0000000..db2ef9d --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterDecoratorWriters.cs @@ -0,0 +1,186 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Writes decorator pattern registration code. + /// + /// + /// Generates code like: + /// + /// services.AddSingleton<IMyService>((IServiceProvider sp) => + /// { + /// var s0 = sp.GetRequiredService<MyService>(); + /// var s1_p0 = sp.GetRequiredService<ILogger<MyServiceDecorator2>>(); + /// var s1 = new MyServiceDecorator2(s1_p0, s0); + /// var s2_p0 = sp.GetRequiredService<ILogger<MyServiceDecorator>>(); + /// var s2 = new MyServiceDecorator(s2_p0, s1); + /// return s2; + /// }); + /// + /// For open generic decorators, falls back to ActivatorUtilities.CreateInstance. + /// + private static void WriteDecoratorRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, ImmutableEquatableSet? asyncInitServiceTypeNames = null) + { + var decorators = registration.Decorators; + var decoratorCount = decorators.Length; + + var serviceTypeParams = registration.ServiceType is GenericTypeData genericServiceType + ? genericServiceType.TypeParameters + : null; + var serviceTypeName = registration.ServiceType.Name; + + var serviceTypeNames = BuildServiceTypeNames(registration); + + writer.WriteServiceLambdaOpen(lifetime, serviceTypeName, registration.Key); + + writer.WriteLine("{"); + writer.Indentation++; + + var methodName = GetServiceResolutionMethod(registration.Key, isOptional: false); + var resolveCall = registration.Key is not null + ? $"sp.{methodName}<{registration.ImplementationType.Name}>(key)" + : $"sp.{methodName}<{registration.ImplementationType.Name}>()"; + writer.WriteLine($"var s0 = {resolveCall};"); + + for(int i = 0; i < decoratorCount; i++) + { + var decorator = decorators[decoratorCount - 1 - i]; + var prevVar = $"s{i}"; + var currentVar = $"s{i + 1}"; + + var decoratorTypeName = GetClosedDecoratorTypeName(decorator, serviceTypeParams); + var ctorParams = decorator.ConstructorParameters; + if(ctorParams is null || ctorParams.Length == 0) + { + writer.WriteLine($"var {currentVar} = global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance<{decoratorTypeName}>(sp, {prevVar});"); + continue; + } + + bool isKeyedRegistration = registration.Key is not null; + WriteConstructInstanceWithInjection( + writer, + instanceVarName: currentVar, + implTypeName: decoratorTypeName, + constructorParams: ctorParams, + injectionMembers: decorator.InjectionMembers ?? [], + isKeyedRegistration: isKeyedRegistration, + registrationKey: registration.Key, + serviceTypeNames: serviceTypeNames, + ctorTypeNameResolver: t => decorator is GenericTypeData { IsOpenGeneric: true } && serviceTypeParams is not null ? SubstituteGenericArguments(t, decorator, serviceTypeParams) : t.Name, + memberTypeNameResolver: t => decorator is GenericTypeData { IsOpenGeneric: true } && serviceTypeParams is not null ? SubstituteGenericArguments(t, decorator, serviceTypeParams) : t.Name, + decoratedPrevVar: prevVar, + asyncInitServiceTypeNames: asyncInitServiceTypeNames); + } + + writer.WriteLine($"return s{decoratorCount};"); + + writer.Indentation--; + writer.WriteLine("});"); + } + + /// + /// Gets the closed decorator type name by substituting generic arguments if the decorator is an open generic. + /// + private static string GetClosedDecoratorTypeName(TypeData decorator, ImmutableEquatableArray? serviceTypeParams) + { + if(decorator is not GenericTypeData { IsOpenGeneric: true } genericDecorator) + { + return decorator.Name; + } + + if(serviceTypeParams is null || serviceTypeParams.Length == 0) + { + return decorator.Name; + } + + return $"{genericDecorator.NameWithoutGeneric}<{string.Join(", ", serviceTypeParams.Select(a => a.Type.Name))}>"; + } + + /// + /// Substitutes generic type parameters in a parameter type with actual generic arguments. + /// + private static string SubstituteGenericArguments(TypeData paramType, TypeData decorator, ImmutableEquatableArray serviceTypeParams) + { + var decoratorTypeParams = decorator is GenericTypeData genericDecorator + ? genericDecorator.TypeParameters + : null; + if(decoratorTypeParams is null || decoratorTypeParams.Length == 0) + { + return paramType.Name; + } + + if(serviceTypeParams.Length != decoratorTypeParams.Length) + { + return paramType.Name; + } + + var result = paramType.Name; + for(int i = 0; i < decoratorTypeParams.Length; i++) + { + result = ReplaceTypeParameter(result, decoratorTypeParams[i].ParameterName, serviceTypeParams[i].Type.Name); + } + + return result; + } + + /// + /// Builds a set of service type names for IsServiceParameter check. + /// + private static HashSet BuildServiceTypeNames(ServiceRegistrationModel registration) + { + var serviceTypeNames = new HashSet(StringComparer.Ordinal); + + AddTypeNameVariants(serviceTypeNames, registration.ServiceType); + AddTypeNameVariants(serviceTypeNames, registration.ImplementationType); + + if(registration.ImplementationType.AllBaseClasses is not null) + { + foreach(var baseClass in registration.ImplementationType.AllBaseClasses) + { + AddTypeNameVariants(serviceTypeNames, baseClass); + } + } + + if(registration.ImplementationType.AllInterfaces is not null) + { + foreach(var iface in registration.ImplementationType.AllInterfaces) + { + AddTypeNameVariants(serviceTypeNames, iface); + } + } + + return serviceTypeNames; + } + + /// + /// Adds both the full name and non-generic name variants to the set. + /// + private static void AddTypeNameVariants(HashSet set, TypeData type) + { + set.Add(type.Name); + if(type is GenericTypeData genericType && type.Name != genericType.NameWithoutGeneric) + { + set.Add(genericType.NameWithoutGeneric); + } + } + + /// + /// Checks if a parameter type matches any of the service types. + /// + private static bool IsServiceTypeParameter(TypeData paramType, string substitutedTypeName, HashSet serviceTypeNames) + { + if(serviceTypeNames.Contains(substitutedTypeName)) + { + return true; + } + + if(serviceTypeNames.Contains(paramType.Name)) + { + return true; + } + + return paramType is GenericTypeData genericParamType + && serviceTypeNames.Contains(genericParamType.NameWithoutGeneric); + } +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterEntry.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterEntry.cs new file mode 100644 index 0000000..d846a1f --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterEntry.cs @@ -0,0 +1,565 @@ +using static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Shared context for writing register entries. + /// + private readonly record struct RegisterWriteContext( + ImmutableEquatableSet? AsyncInitServiceTypeNames); + + /// + /// Base model for a single registration entry that can write itself. + /// + private abstract record class RegisterEntry(ServiceRegistrationModel Registration) + { + public abstract void WriteRegistration(SourceWriter writer, RegisterWriteContext context); + } + + /// + /// Registration entry for simple and open generic service registrations. + /// + private sealed record class SimpleRegisterEntry : RegisterEntry + { + public string ServiceTypeExpression { get; } + + public string ImplementationTypeExpression { get; } + + public bool IsKeyed { get; } + + public SimpleRegisterEntry(ServiceRegistrationModel registration) + : base(registration) + { + IsKeyed = registration.Key is not null; + + if(registration.IsOpenGeneric) + { + ServiceTypeExpression = ConvertToTypeOf(registration.ServiceType); + ImplementationTypeExpression = ConvertToTypeOf(registration.ImplementationType); + return; + } + + ServiceTypeExpression = registration.ServiceType.Name; + ImplementationTypeExpression = registration.ImplementationType.Name; + } + + public override void WriteRegistration(SourceWriter writer, RegisterWriteContext context) + { + var lifetime = Registration.Lifetime.Name; + + if(Registration.IsOpenGeneric) + { + if(IsKeyed) + { + writer.WriteLine($"services.AddKeyed{lifetime}({ServiceTypeExpression}, {Registration.Key}, {ImplementationTypeExpression});"); + } + else + { + writer.WriteLine($"services.Add{lifetime}({ServiceTypeExpression}, {ImplementationTypeExpression});"); + } + + return; + } + + if(IsKeyed) + { + writer.WriteLine($"services.AddKeyed{lifetime}<{ServiceTypeExpression}, {ImplementationTypeExpression}>({Registration.Key});"); + } + else + { + writer.WriteLine($"services.Add{lifetime}<{ServiceTypeExpression}, {ImplementationTypeExpression}>();"); + } + } + } + + /// + /// Registration entry for static instance registrations. + /// + private sealed record class InstanceRegisterEntry : RegisterEntry + { + public string InstanceExpression { get; } + + public bool IsKeyed { get; } + + public InstanceRegisterEntry(ServiceRegistrationModel registration) + : base(registration) + { + InstanceExpression = registration.Instance!; + IsKeyed = registration.Key is not null; + } + + public override void WriteRegistration(SourceWriter writer, RegisterWriteContext context) + { + var serviceTypeName = Registration.ServiceType.Name; + + if(IsKeyed) + { + writer.WriteLine($"services.AddKeyedSingleton<{serviceTypeName}>({Registration.Key}, {InstanceExpression});"); + } + else + { + writer.WriteLine($"services.AddSingleton<{serviceTypeName}>({InstanceExpression});"); + } + } + } + + /// + /// Registration entry for service-type forwarding registrations. + /// + private sealed record class ForwardingRegisterEntry : RegisterEntry + { + public string? Key { get; } + + public ForwardingRegisterEntry(ServiceRegistrationModel registration) + : base(registration) + { + Key = registration.Key; + } + + public override void WriteRegistration(SourceWriter writer, RegisterWriteContext context) + { + var serviceTypeName = Registration.ServiceType.Name; + var implTypeName = Registration.ImplementationType.Name; + var lifetime = Registration.Lifetime.Name; + var requiredResolutionMethod = GetServiceResolutionMethod(Key, isOptional: false); + bool isAsyncInit = context.AsyncInitServiceTypeNames?.Contains(implTypeName) == true; + + if(isAsyncInit) + { + var taskServiceTypeName = $"global::System.Threading.Tasks.Task<{serviceTypeName}>"; + var taskImplTypeName = $"global::System.Threading.Tasks.Task<{implTypeName}>"; + if(Key is not null) + { + var requiredCall = BuildServiceCall(requiredResolutionMethod, taskImplTypeName, serviceKey: "key"); + writer.WriteLine($"services.AddKeyed{lifetime}<{taskServiceTypeName}>({Key}, async ({IServiceProviderGlobalTypeName} sp, object? key) => await {requiredCall});"); + } + else + { + var requiredCall = BuildServiceCall(requiredResolutionMethod, taskImplTypeName, serviceKey: null); + writer.WriteLine($"services.Add{lifetime}<{taskServiceTypeName}>(async ({IServiceProviderGlobalTypeName} sp) => await {requiredCall});"); + } + + return; + } + + if(Key is not null) + { + var requiredCall = BuildServiceCall(requiredResolutionMethod, implTypeName, serviceKey: "key"); + writer.WriteLine($"services.AddKeyed{lifetime}<{serviceTypeName}>({Key}, ({IServiceProviderGlobalTypeName} sp, object? key) => {requiredCall});"); + } + else + { + var requiredCall = BuildServiceCall(requiredResolutionMethod, implTypeName, serviceKey: null); + writer.WriteLine($"services.Add{lifetime}<{serviceTypeName}>(({IServiceProviderGlobalTypeName} sp) => {requiredCall});"); + } + } + } + + /// + /// Additional factory parameter binding with pre-computed temporary variable name. + /// + private readonly record struct FactoryAdditionalParameter( + ParameterData Parameter, + string TemporaryVariableName); + + /// + /// Registration entry for factory method registrations. + /// + private sealed record class FactoryRegisterEntry : RegisterEntry + { + public FactoryMethodData FactoryMethodData { get; } + + public string Lifetime { get; } + + public string ServiceTypeName { get; } + + public string? Key { get; } + + public bool IsKeyedRegistration { get; } + + public bool HasServiceProvider { get; } + + public bool HasKey { get; } + + public bool NeedsCast { get; } + + public string FactoryCallPath { get; } + + public bool CanGenerate { get; } + + public ImmutableEquatableArray AdditionalParameters { get; } + + public FactoryRegisterEntry(ServiceRegistrationModel registration) + : base(registration) + { + ServiceTypeName = registration.ServiceType.Name; + Lifetime = registration.Lifetime.Name; + Key = registration.Key; + IsKeyedRegistration = registration.Key is not null; + + FactoryMethodData = registration.Factory!; + HasServiceProvider = FactoryMethodData.HasServiceProvider; + HasKey = FactoryMethodData.HasKey; + NeedsCast = FactoryMethodData.ReturnTypeName is not null && FactoryMethodData.ReturnTypeName != ServiceTypeName; + + var genericTypeArgs = BuildGenericFactoryTypeArgs(FactoryMethodData, registration.ServiceType); + if(FactoryMethodData.TypeParameterCount > 0 && genericTypeArgs is null) + { + CanGenerate = false; + FactoryCallPath = FactoryMethodData.Path; + } + else + { + CanGenerate = true; + FactoryCallPath = genericTypeArgs is not null + ? $"{FactoryMethodData.Path}<{genericTypeArgs}>" + : FactoryMethodData.Path; + } + + var additional = new FactoryAdditionalParameter[FactoryMethodData.AdditionalParameters.Length]; + for(int i = 0; i < additional.Length; i++) + { + additional[i] = new FactoryAdditionalParameter( + FactoryMethodData.AdditionalParameters[i], + $"f_p{i}"); + } + + AdditionalParameters = additional.ToImmutableEquatableArray(); + } + + public override void WriteRegistration(SourceWriter writer, RegisterWriteContext context) + { + if(!CanGenerate) + { + return; + } + + if(AdditionalParameters.Length > 0) + { + WriteRegistrationWithAdditionalParameters(writer, context); + return; + } + + var factoryInvocation = BuildFactoryInvocationExpression([]); + WriteFactoryRegistrationLine(writer, Lifetime, ServiceTypeName, Key, factoryInvocation); + } + + private void WriteRegistrationWithAdditionalParameters(SourceWriter writer, RegisterWriteContext context) + { + writer.WriteServiceLambdaOpen(Lifetime, ServiceTypeName, Key); + + writer.WriteLine("{"); + writer.Indentation++; + + var resolvedParameterNames = new List(AdditionalParameters.Length); + foreach(var additionalParam in AdditionalParameters) + { + var resolvedName = ResolveParamAndEmitVar( + writer, + additionalParam.Parameter, + additionalParam.TemporaryVariableName, + IsKeyedRegistration, + Key, + context.AsyncInitServiceTypeNames); + resolvedParameterNames.Add(resolvedName); + } + + var factoryInvocation = BuildFactoryInvocationExpression(resolvedParameterNames); + writer.WriteLine($"return {factoryInvocation};"); + + writer.Indentation--; + writer.WriteLine("});"); + } + + private string BuildFactoryInvocationExpression(List additionalArguments) + { + var args = new List(additionalArguments.Count + 2); + if(HasServiceProvider) + { + args.Add("sp"); + } + + if(HasKey && Key is not null) + { + args.Add(Key); + } + + if(additionalArguments.Count > 0) + { + args.AddRange(additionalArguments); + } + + var invocation = $"{FactoryCallPath}({string.Join(", ", args)})"; + if(NeedsCast) + { + invocation = $"({ServiceTypeName}){invocation}"; + } + + return invocation; + } + } + + /// + /// Registration entry for services requiring constructor/property/method injection. + /// + private sealed record class InjectionRegisterEntry : RegisterEntry + { + public string Lifetime { get; } + + public string ServiceTypeName { get; } + + public string ImplementationTypeName { get; } + + public string? Key { get; } + + public bool IsKeyedRegistration { get; } + + public ImmutableEquatableArray ConstructorParameters { get; } + + public ImmutableEquatableArray PropertyInjectionMembers { get; } + + public ImmutableEquatableArray MethodInjectionMembers { get; } + + public InjectionRegisterEntry(ServiceRegistrationModel registration) + : base(registration) + { + Lifetime = registration.Lifetime.Name; + ServiceTypeName = registration.ServiceType.Name; + ImplementationTypeName = registration.ImplementationType.Name; + Key = registration.Key; + IsKeyedRegistration = registration.Key is not null; + ConstructorParameters = registration.ImplementationType.ConstructorParameters ?? []; + + var (properties, methods) = CategorizeInjectionMembers(registration.InjectionMembers); + PropertyInjectionMembers = properties?.ToImmutableEquatableArray() ?? []; + MethodInjectionMembers = methods?.ToImmutableEquatableArray() ?? []; + } + + public override void WriteRegistration(SourceWriter writer, RegisterWriteContext context) + { + writer.WriteServiceLambdaOpen(Lifetime, ServiceTypeName, Key); + + writer.WriteLine("{"); + writer.Indentation++; + + WriteInjectionBody(writer, context, null); + + writer.WriteLine("return s0;"); + + writer.Indentation--; + writer.WriteLine("});"); + } + + internal void WriteInjectionBody( + SourceWriter writer, + RegisterWriteContext context, + ImmutableEquatableArray? asyncMethodInjectionMembers) + { + var constructorParamEntries = new List<(string Name, string? Value)>(ConstructorParameters.Length); + + for(int i = 0; i < ConstructorParameters.Length; i++) + { + var parameter = ConstructorParameters[i]; + var paramVarName = $"p{i}"; + var resolvedVar = ResolveParamAndEmitVar( + writer, + parameter, + paramVarName, + IsKeyedRegistration, + Key, + context.AsyncInitServiceTypeNames); + + constructorParamEntries.Add((parameter.Name, resolvedVar)); + } + + var propertyVarNames = new string[PropertyInjectionMembers.Length]; + for(int i = 0; i < PropertyInjectionMembers.Length; i++) + { + var member = PropertyInjectionMembers[i]; + var memberVarName = $"s0_p{i}"; + var memberType = member.Type; + var memberTypeName = memberType is null ? "object" : memberType.Name; + bool hasNonNullDefault = member.HasDefaultValue && !member.DefaultValueIsNull; + + ResolveMemberValue( + writer, + memberType, + memberTypeName, + memberVarName, + member.Key, + member.IsNullable, + hasNonNullDefault, + member.DefaultValue, + context.AsyncInitServiceTypeNames); + + propertyVarNames[i] = memberVarName; + } + + int memberParamIndex = PropertyInjectionMembers.Length; + var syncMethodResolutions = ResolveMethodResolutions(writer, MethodInjectionMembers, ref memberParamIndex, context); + var asyncMethodResolutions = asyncMethodInjectionMembers is { Length: > 0 } asyncMembers + ? ResolveMethodResolutions(writer, asyncMembers, ref memberParamIndex, context) + : []; + + var propertyInitializers = new string[propertyVarNames.Length]; + for(int i = 0; i < PropertyInjectionMembers.Length; i++) + { + propertyInitializers[i] = $"{PropertyInjectionMembers[i].Name} = {propertyVarNames[i]}"; + } + + var constructorArgs = BuildArgumentListFromEntries(constructorParamEntries); + var initializerPart = propertyInitializers.Length > 0 + ? $" {{ {string.Join(", ", propertyInitializers)} }}" + : ""; + var constructorInvocation = BuildConstructorInvocation(ImplementationTypeName, constructorArgs, initializerPart); + writer.WriteLine($"var s0 = {constructorInvocation};"); + + WriteMethodInvocations(writer, syncMethodResolutions, useAwait: false); + WriteMethodInvocations(writer, asyncMethodResolutions, useAwait: true); + } + + private List<(string MethodName, string?[] ParamVars, string[] ParamNames)> ResolveMethodResolutions( + SourceWriter writer, + ImmutableEquatableArray methodMembers, + ref int memberParamIndex, + RegisterWriteContext context) + { + var methodResolutions = new List<(string MethodName, string?[] ParamVars, string[] ParamNames)>(methodMembers.Length); + + foreach(var method in methodMembers) + { + var parameters = method.Parameters ?? []; + var paramVars = new string?[parameters.Length]; + var paramNames = new string[parameters.Length]; + + for(int i = 0; i < parameters.Length; i++) + { + var parameter = parameters[i]; + var paramVarName = $"s0_m{memberParamIndex}"; + paramNames[i] = parameter.Name; + + paramVars[i] = method.Key is not null + ? ResolveMethodParameterWithKey( + writer, + parameter, + paramVarName, + method.Key, + IsKeyedRegistration, + Key, + typeNameResolver: null, + context.AsyncInitServiceTypeNames) + : ResolveParamAndEmitVar( + writer, + parameter, + paramVarName, + IsKeyedRegistration, + Key, + context.AsyncInitServiceTypeNames); + + memberParamIndex++; + } + + methodResolutions.Add((method.Name, paramVars, paramNames)); + } + + return methodResolutions; + } + + private static void WriteMethodInvocations( + SourceWriter writer, + List<(string MethodName, string?[] ParamVars, string[] ParamNames)> methodResolutions, + bool useAwait) + { + foreach(var (methodName, paramVars, paramNames) in methodResolutions) + { + var entries = new (string Name, string? Value)[paramVars.Length]; + + for(int i = 0; i < paramVars.Length; i++) + { + entries[i] = (paramNames[i], paramVars[i]); + } + + string args = BuildArgumentListFromEntries(entries); + writer.WriteLine(useAwait + ? $"await s0.{methodName}({args});" + : $"s0.{methodName}({args});"); + } + } + } + + /// + /// Registration entry for services requiring async injection method initialization. + /// + private sealed record class AsyncInjectionRegisterEntry : RegisterEntry + { + private readonly InjectionRegisterEntry injectionEntry; + + public string TaskServiceTypeName { get; } + + public string TaskImplementationTypeName { get; } + + public ImmutableEquatableArray AsyncMethodInjectionMembers { get; } + + public AsyncInjectionRegisterEntry(ServiceRegistrationModel registration) + : base(registration) + { + injectionEntry = new InjectionRegisterEntry(registration); + TaskServiceTypeName = $"global::System.Threading.Tasks.Task<{registration.ServiceType.Name}>"; + TaskImplementationTypeName = $"global::System.Threading.Tasks.Task<{registration.ImplementationType.Name}>"; + AsyncMethodInjectionMembers = registration.InjectionMembers + .Where(static member => member.MemberType == InjectionMemberType.AsyncMethod) + .ToImmutableEquatableArray(); + } + + public override void WriteRegistration(SourceWriter writer, RegisterWriteContext context) + { + writer.WriteServiceLambdaOpen(Registration.Lifetime.Name, TaskServiceTypeName, Registration.Key); + + writer.WriteLine("{"); + writer.Indentation++; + + writer.WriteLine($"async {TaskImplementationTypeName} Init()"); + writer.WriteLine("{"); + writer.Indentation++; + + injectionEntry.WriteInjectionBody(writer, context, AsyncMethodInjectionMembers); + + writer.WriteLine("return s0;"); + + writer.Indentation--; + writer.WriteLine("}"); + + writer.WriteLine("return Init();"); + + writer.Indentation--; + writer.WriteLine("});"); + } + } + + /// + /// Registration entry for decorator-chain registrations. + /// + private sealed record class DecoratorRegisterEntry : RegisterEntry + { + public string Lifetime { get; } + + public ImmutableEquatableArray Decorators { get; } + + public FactoryMethodData? Factory { get; } + + public DecoratorRegisterEntry(ServiceRegistrationModel registration) + : base(registration) + { + Lifetime = registration.Lifetime.Name; + Decorators = registration.Decorators; + Factory = registration.Factory; + } + + public override void WriteRegistration(SourceWriter writer, RegisterWriteContext context) + { + WriteDecoratorRegistration(writer, Registration, Lifetime, context.AsyncInitServiceTypeNames); + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterFactoryPatternWriters.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterFactoryPatternWriters.cs new file mode 100644 index 0000000..ddf1dbd --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterFactoryPatternWriters.cs @@ -0,0 +1,25 @@ +using static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Writes a factory registration line for keyed or non-keyed services. + /// + private static void WriteFactoryRegistrationLine( + SourceWriter writer, + string lifetime, + string serviceTypeName, + string? registrationKey, + string factoryInvocation) + { + if(registrationKey is not null) + { + writer.WriteLine($"services.AddKeyed{lifetime}<{serviceTypeName}>({registrationKey}, ({IServiceProviderGlobalTypeName} sp, object? key) => {factoryInvocation});"); + return; + } + + writer.WriteLine($"services.Add{lifetime}<{serviceTypeName}>(({IServiceProviderGlobalTypeName} sp) => {factoryInvocation});"); + } +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterInstanceWriters.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterInstanceWriters.cs new file mode 100644 index 0000000..35329af --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterInstanceWriters.cs @@ -0,0 +1,199 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Shared helper to construct an instance and apply property/field/method injection. + /// Supports decorator scenarios via service-parameter detection and generic type substitution. + /// + private static void WriteConstructInstanceWithInjection( + SourceWriter writer, + string instanceVarName, + string implTypeName, + ImmutableEquatableArray? constructorParams, + ImmutableEquatableArray injectionMembers, + bool isKeyedRegistration, + string? registrationKey, + HashSet? serviceTypeNames, + Func? ctorTypeNameResolver, + Func? memberTypeNameResolver, + string? decoratedPrevVar, + ImmutableEquatableSet? asyncInitServiceTypeNames = null, + bool isAsyncMode = false) + { + var ctorParams = constructorParams ?? []; + var constructorParamEntries = new List<(string Name, string? Value)>(ctorParams.Length); + int paramIndex = 0; + + foreach(var param in ctorParams) + { + var resolvedTypeName = ctorTypeNameResolver is not null ? ctorTypeNameResolver(param.Type) : param.Type.Name; + if(decoratedPrevVar is not null && serviceTypeNames is not null && IsServiceTypeParameter(param.Type, resolvedTypeName, serviceTypeNames)) + { + constructorParamEntries.Add((param.Name, decoratedPrevVar)); + } + else + { + var varName = decoratedPrevVar is not null ? $"{instanceVarName}_p{paramIndex}" : $"p{paramIndex}"; + var resolvedVar = ResolveParamAndEmitVar(writer, param, varName, isKeyedRegistration, registrationKey, asyncInitServiceTypeNames); + constructorParamEntries.Add((param.Name, resolvedVar)); + paramIndex++; + } + } + + EmitConstruction( + writer, + instanceVarName, + implTypeName, + constructorParamEntries, + injectionMembers, + isKeyedRegistration, + registrationKey, + memberTypeNameResolver, + asyncInitServiceTypeNames, + isAsyncMode); + } + + private static void EmitConstruction( + SourceWriter writer, + string instanceVarName, + string implTypeName, + List<(string Name, string? Value)> constructorParamEntries, + ImmutableEquatableArray injectionMembers, + bool isKeyedRegistration, + string? registrationKey, + Func? memberTypeNameResolver, + ImmutableEquatableSet? asyncInitServiceTypeNames = null, + bool isAsyncMode = false) + { + List preProps = []; + int pfIdxCounter = 0; + foreach(var m in injectionMembers) + { + if(m.MemberType is not (InjectionMemberType.Property or InjectionMemberType.Field)) + continue; + + var varN = $"{instanceVarName}_p{pfIdxCounter}"; + var mt = m.Type; + var mtName = mt is null ? "object" : (memberTypeNameResolver is not null ? memberTypeNameResolver(mt) : mt.Name); + bool hasNonNullDefault = m.HasDefaultValue && !m.DefaultValueIsNull; + ResolveMemberValue(writer, mt, mtName, varN, m.Key, m.IsNullable, hasNonNullDefault, m.DefaultValue, asyncInitServiceTypeNames); + preProps.Add($"{m.Name} = {varN}"); + pfIdxCounter++; + } + + int memberParamIndex = pfIdxCounter; + var methodParamResolutions = ResolveMethodParamResolutions( + writer, + injectionMembers, + InjectionMemberType.Method, + instanceVarName, + ref memberParamIndex, + isKeyedRegistration, + registrationKey, + memberTypeNameResolver, + asyncInitServiceTypeNames); + + var asyncMethodParamResolutions = isAsyncMode + ? ResolveMethodParamResolutions( + writer, + injectionMembers, + InjectionMemberType.AsyncMethod, + instanceVarName, + ref memberParamIndex, + isKeyedRegistration, + registrationKey, + memberTypeNameResolver, + asyncInitServiceTypeNames) + : []; + + var constructorArgs = BuildArgumentListFromEntries(constructorParamEntries); + var initializerPart = preProps.Count > 0 ? $" {{ {string.Join(", ", preProps)} }}" : ""; + var constructorInvocation = BuildConstructorInvocation(implTypeName, constructorArgs, initializerPart); + writer.WriteLine($"var {instanceVarName} = {constructorInvocation};"); + + EmitMethodInvocations(writer, instanceVarName, methodParamResolutions, useAwait: false); + EmitMethodInvocations(writer, instanceVarName, asyncMethodParamResolutions, useAwait: true); + } + + /// + /// Resolves method parameters of a given and emits their variable declarations. + /// Shared by sync () and async () resolution loops. + /// + private static List<(string MethodName, string?[] ParamVars, string[] ParamNames)> ResolveMethodParamResolutions( + SourceWriter writer, + ImmutableEquatableArray injectionMembers, + InjectionMemberType targetType, + string instanceVarName, + ref int memberParamIndex, + bool isKeyedRegistration, + string? registrationKey, + Func? memberTypeNameResolver, + ImmutableEquatableSet? asyncInitServiceTypeNames) + { + var resolutions = new List<(string MethodName, string?[] ParamVars, string[] ParamNames)>(); + foreach(var method in injectionMembers) + { + if(method.MemberType != targetType) + continue; + + var mParams = method.Parameters ?? []; + var mVars = new string?[mParams.Length]; + var mNames = new string[mParams.Length]; + int mi = 0; + foreach(var p in mParams) + { + var pVar = $"{instanceVarName}_m{memberParamIndex}"; + mNames[mi] = p.Name; + if(method.Key is not null) + { + mVars[mi] = ResolveMethodParameterWithKey( + writer, + p, + pVar, + method.Key, + isKeyedRegistration, + registrationKey, + memberTypeNameResolver, + asyncInitServiceTypeNames); + } + else + { + mVars[mi] = ResolveParamAndEmitVar(writer, p, pVar, isKeyedRegistration, registrationKey, asyncInitServiceTypeNames); + } + + mi++; + memberParamIndex++; + } + + resolutions.Add((method.Name, mVars, mNames)); + } + + return resolutions; + } + + /// + /// Emits method invocation statements for the given . + /// When is , each call is prefixed with await. + /// + private static void EmitMethodInvocations( + SourceWriter writer, + string instanceVarName, + List<(string MethodName, string?[] ParamVars, string[] ParamNames)> resolutions, + bool useAwait) + { + foreach(var (mName, mVars, mNames) in resolutions) + { + var entries = new (string Name, string? Value)[mVars.Length]; + for(int i = 0; i < mVars.Length; i++) + { + entries[i] = (mNames[i], mVars[i]); + } + + var args = BuildArgumentListFromEntries(entries); + writer.WriteLine(useAwait + ? $"await {instanceVarName}.{mName}({args});" + : $"{instanceVarName}.{mName}({args});"); + } + } +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterOutputModel.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterOutputModel.cs new file mode 100644 index 0000000..cb061bc --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterOutputModel.cs @@ -0,0 +1,14 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Top-level output model for register source generation. + /// + private sealed record class RegisterOutputModel( + string MethodBaseName, + string RootNamespace, + string AssemblyName, + ImmutableEquatableArray TagGroups, + ImmutableEquatableSet? AsyncInitServiceTypes); +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterResolutionWriters.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterResolutionWriters.cs new file mode 100644 index 0000000..2fb231b --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/RegisterResolutionWriters.cs @@ -0,0 +1,379 @@ +using static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Resolves a property or field injection value and emits its variable declaration. + /// + private static void ResolveMemberValue( + SourceWriter writer, + TypeData? memberType, + string memberTypeName, + string paramVar, + string? serviceKey, + bool isNullable, + bool hasNonNullDefault, + string? defaultValue, + ImmutableEquatableSet? asyncInitServiceTypeNames = null) + { + if(memberType is CollectionWrapperTypeData) + { + WriteCollectionResolution(writer, memberType, paramVar, serviceKey, isOptional: isNullable); + return; + } + + if(memberType is LazyTypeData or FuncTypeData or DictionaryTypeData or KeyValuePairTypeData or TaskTypeData) + { + WriteWrapperResolution(writer, memberType, paramVar, serviceKey, isOptional: isNullable, asyncInitServiceTypeNames); + return; + } + + if(hasNonNullDefault) + { + var defExpr = defaultValue ?? "default"; + var methodName = GetServiceResolutionMethod(serviceKey, isOptional: true); + var svcCall = BuildServiceCall(methodName, memberTypeName, serviceKey); + writer.WriteLine($"var {paramVar} = {svcCall} ?? {defExpr};"); + return; + } + + var resolutionMethod = GetServiceResolutionMethod(serviceKey, isOptional: isNullable); + var call = BuildServiceCall(resolutionMethod, memberTypeName, serviceKey); + writer.WriteLine($"var {paramVar} = {call};"); + } + + /// + /// Resolves a keyed method parameter and emits its variable declaration. + /// + private static string? ResolveMethodParameterWithKey( + SourceWriter writer, + ParameterData param, + string paramVar, + string methodKey, + bool isKeyedRegistration, + string? registrationKey, + Func? typeNameResolver, + ImmutableEquatableSet? asyncInitServiceTypeNames = null) + { + if(TryResolveCommonParameter(writer, param, paramVar, isKeyedRegistration, registrationKey, methodKey, isOptional: false, typeNameResolver, asyncInitServiceTypeNames, out var resolvedVar)) + { + return resolvedVar!; + } + + var resolvedTypeName = typeNameResolver is not null ? typeNameResolver(param.Type) : param.Type.Name; + if(param.HasDefaultValue) + { + var defExpr = param.DefaultValue ?? "default"; + var optionalCall = BuildServiceCall(GetServiceResolutionMethod(methodKey, isOptional: true), resolvedTypeName, methodKey); + writer.WriteLine($"var {paramVar} = {optionalCall} ?? {defExpr};"); + return paramVar; + } + + var requiredCall = BuildServiceCall(GetServiceResolutionMethod(methodKey, isOptional: false), resolvedTypeName, methodKey); + writer.WriteLine($"var {paramVar} = {requiredCall};"); + return paramVar; + } + + /// + /// Writes parameter resolution code for [ServiceKey] attribute. + /// When the service is registered as keyed, injects the registration key with appropriate type casting. + /// When the service is not keyed, injects null. + /// + private static string WriteServiceKeyParameterResolution( + SourceWriter writer, + string paramVar, + string paramTypeName, + bool isKeyedRegistration, + string? registrationKey) + { + if(isKeyedRegistration && registrationKey is not null) + { + writer.WriteLine($"var {paramVar} = {registrationKey};"); + } + else + { + writer.WriteLine($"var {paramVar} = default({paramTypeName});"); + } + + return paramVar; + } + + /// + /// Resolve a parameter and emit its variable declaration. + /// Produces lines like: + /// var p = sp.GetService() ?? ; + /// var p = sp.GetKeyedService(key) ?? ; + /// or returns "sp" for IServiceProvider parameters (no var emitted). + /// + private static string ResolveParamAndEmitVar( + SourceWriter writer, + ParameterData param, + string paramVar, + bool isKeyedRegistration, + string? registrationKey = null, + ImmutableEquatableSet? asyncInitServiceTypeNames = null) + { + var paramTypeName = param.Type.Name; + + if(TryResolveCommonParameter(writer, param, paramVar, isKeyedRegistration, registrationKey, param.ServiceKey, param.IsOptional, typeNameResolver: null, asyncInitServiceTypeNames, out var resolvedVar)) + { + return resolvedVar!; + } + + var isOptional = param.HasDefaultValue || param.IsOptional; + var methodName = GetServiceResolutionMethod(param.ServiceKey, isOptional); + var svcCall = BuildServiceCall(methodName, paramTypeName, param.ServiceKey); + + if(isOptional) + { + var defExpr = param.HasDefaultValue ? (param.DefaultValue ?? "default") : "default"; + writer.WriteLine($"var {paramVar} = {svcCall} ?? {defExpr};"); + return paramVar; + } + + writer.WriteLine($"var {paramVar} = {svcCall};"); + return paramVar; + } + + /// + /// Writes collection resolution code for constructor parameters and injection members. + /// Handles all collection types including IEnumerable<T>, IList<T>, T[], IReadOnlyList<T>, etc. + /// + private static void WriteCollectionResolution( + SourceWriter writer, + TypeData type, + string paramVar, + string? serviceKey, + bool isOptional = false) + { + var servicesMethod = serviceKey is not null ? GetKeyedServices : GetServices; + + switch(type) + { + case EnumerableTypeData enumerable: + { + var call = BuildServiceCall(servicesMethod, enumerable.ElementType.Name, serviceKey); + writer.WriteLine($"var {paramVar} = {call};"); + return; + } + + case ReadOnlyCollectionTypeData readOnlyCollection: + { + var call = BuildServiceCall(servicesMethod, readOnlyCollection.ElementType.Name, serviceKey); + writer.WriteLine($"var {paramVar} = {call}.ToArray();"); + return; + } + + case CollectionTypeData collection: + { + var call = BuildServiceCall(servicesMethod, collection.ElementType.Name, serviceKey); + writer.WriteLine($"var {paramVar} = {call}.ToArray();"); + return; + } + + case ReadOnlyListTypeData readOnlyList: + { + var call = BuildServiceCall(servicesMethod, readOnlyList.ElementType.Name, serviceKey); + writer.WriteLine($"var {paramVar} = {call}.ToArray();"); + return; + } + + case ListTypeData list: + { + var call = BuildServiceCall(servicesMethod, list.ElementType.Name, serviceKey); + writer.WriteLine($"var {paramVar} = {call}.ToArray();"); + return; + } + + case ArrayTypeData array: + { + var call = BuildServiceCall(servicesMethod, array.ElementType.Name, serviceKey); + writer.WriteLine($"var {paramVar} = {call}.ToArray();"); + return; + } + + default: + { + var methodName = GetServiceResolutionMethod(serviceKey, isOptional); + var call = BuildServiceCall(methodName, type.Name, serviceKey); + writer.WriteLine($"var {paramVar} = {call};"); + return; + } + } + } + + /// + /// Writes wrapper type resolution code for Lazy<T>, Func<T>, IDictionary<TKey, TValue>, + /// and KeyValuePair<TKey, TValue>. Supports nested wrapper types (e.g., Lazy<KeyValuePair<K, V>>). + /// + private static void WriteWrapperResolution( + SourceWriter writer, + TypeData type, + string paramVar, + string? serviceKey, + bool isOptional = false, + ImmutableEquatableSet? asyncInitServiceTypeNames = null) + { + var expr = BuildWrapperExpression(type, serviceKey, isOptional, asyncInitServiceTypeNames); + writer.WriteLine($"var {paramVar} = {expr};"); + } + + /// + /// Builds an inline wrapper expression. Recursively handles nested wrappers. + /// + private static string BuildWrapperExpression(TypeData type, string? serviceKey, bool isOptional, ImmutableEquatableSet? asyncInitServiceTypeNames = null) + { + switch(type) + { + case LazyTypeData lazy: + { + var innerType = lazy.InstanceType; + if(innerType is not WrapperTypeData) + { + var methodName = GetServiceResolutionMethod(serviceKey, isOptional); + return BuildServiceCall(methodName, type.Name, serviceKey); + } + + var lazyInnerExpr = BuildInnerResolutionExpression(innerType, serviceKey, isOptional, asyncInitServiceTypeNames); + return $"new global::System.Lazy<{innerType.Name}>(() => {lazyInnerExpr}, global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)"; + } + + case FuncTypeData func: + { + var innerType = func.ReturnType; + if(func.HasInputParameters || innerType is not WrapperTypeData) + { + var methodName = GetServiceResolutionMethod(serviceKey, isOptional); + return BuildServiceCall(methodName, type.Name, serviceKey); + } + + var funcInnerExpr = BuildInnerResolutionExpression(innerType, serviceKey, isOptional, asyncInitServiceTypeNames); + return $"new global::System.Func<{innerType.Name}>(() => {funcInnerExpr})"; + } + + case KeyValuePairTypeData kvp: + { + var keyExpr = serviceKey ?? "default"; + var valueExpr = BuildInnerResolutionExpression(kvp.ValueType, serviceKey, isOptional, asyncInitServiceTypeNames); + return $"new global::System.Collections.Generic.KeyValuePair<{kvp.KeyType.Name}, {kvp.ValueType.Name}>({keyExpr}, {valueExpr})"; + } + + case DictionaryTypeData dict: + { + var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{dict.KeyType.Name}, {dict.ValueType.Name}>"; + var getServicesCall = BuildServiceCall(serviceKey is not null ? GetKeyedServices : GetServices, kvpTypeName, serviceKey); + return $"{getServicesCall}.ToDictionary()"; + } + + case TaskTypeData task: + { + var innerTypeName = task.InnerType.Name; + if(asyncInitServiceTypeNames?.Contains(innerTypeName) == true) + { + var methodName = GetServiceResolutionMethod(serviceKey, isOptional); + return BuildServiceCall(methodName, type.Name, serviceKey); + } + + var syncMethodName = GetServiceResolutionMethod(serviceKey, isOptional); + var syncCall = BuildServiceCall(syncMethodName, innerTypeName, serviceKey); + return $"global::System.Threading.Tasks.Task.FromResult({syncCall})"; + } + + default: + { + var methodName = GetServiceResolutionMethod(serviceKey, isOptional); + return BuildServiceCall(methodName, type.Name, serviceKey); + } + } + } + + /// + /// Builds an inner resolution expression β€” either a nested wrapper expression, a collection + /// expression, or a direct service call. Supports nesting such as Lazy<IEnumerable<T>>. + /// + private static string BuildInnerResolutionExpression(TypeData innerType, string? serviceKey, bool isOptional, ImmutableEquatableSet? asyncInitServiceTypeNames = null) + { + if(innerType is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData or TaskTypeData) + { + return BuildWrapperExpression(innerType, serviceKey, isOptional, asyncInitServiceTypeNames); + } + + if(innerType is CollectionWrapperTypeData collectionInner) + { + var getServicesCall = BuildServiceCall( + serviceKey is not null ? GetKeyedServices : GetServices, + collectionInner.ElementType.Name, + serviceKey); + + return innerType is CollectionWrapperTypeData and not EnumerableTypeData + ? $"{getServicesCall}.ToArray()" + : getServicesCall; + } + + var methodName = GetServiceResolutionMethod(serviceKey, isOptional); + return BuildServiceCall(methodName, innerType.Name, serviceKey); + } + + /// + /// Builds a service resolution call for keyed or non-keyed services. + /// + private static string BuildServiceCall(string methodName, string typeName, string? serviceKey) => + serviceKey is not null + ? $"sp.{methodName}<{typeName}>({serviceKey})" + : $"sp.{methodName}<{typeName}>()"; + + /// + /// Attempts to resolve common parameter cases and emit any required variable declarations. + /// Returns true when a resolution was produced. + /// + private static bool TryResolveCommonParameter( + SourceWriter writer, + ParameterData param, + string paramVar, + bool isKeyedRegistration, + string? registrationKey, + string? serviceKey, + bool isOptional, + Func? typeNameResolver, + ImmutableEquatableSet? asyncInitServiceTypeNames, + out string? resolvedVar) + { + if(IsServiceProviderType(param.Type.Name)) + { + resolvedVar = "sp"; + return true; + } + + var resolvedTypeName = typeNameResolver is not null ? typeNameResolver(param.Type) : param.Type.Name; + if(param.HasServiceKeyAttribute) + { + resolvedVar = WriteServiceKeyParameterResolution(writer, paramVar, resolvedTypeName, isKeyedRegistration, registrationKey); + return true; + } + + if(param.Type is CollectionWrapperTypeData) + { + WriteCollectionResolution(writer, param.Type, paramVar, serviceKey, isOptional); + resolvedVar = paramVar; + return true; + } + + if(param.Type is LazyTypeData or FuncTypeData or DictionaryTypeData or KeyValuePairTypeData or TaskTypeData) + { + WriteWrapperResolution(writer, param.Type, paramVar, serviceKey, isOptional, asyncInitServiceTypeNames); + resolvedVar = paramVar; + return true; + } + + resolvedVar = null; + return false; + } + + /// + /// Checks if the type name represents System.IServiceProvider. + /// + private static bool IsServiceProviderType(string typeName) => + typeName is IServiceProviderGlobalTypeName or IServiceProviderTypeName or "IServiceProvider"; +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/WrapperRegistrationEntries.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/WrapperRegistrationEntries.cs new file mode 100644 index 0000000..8243678 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Register/WrapperRegistrationEntries.cs @@ -0,0 +1,55 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Represents a Lazy standalone service registration entry to be generated. + /// + /// The fully-qualified service type name. + /// The fully-qualified implementation type name. + /// The service lifetime matching the inner service. + /// The tags inherited from the source registration. + private readonly partial record struct LazyRegistrationEntry( + string InnerServiceTypeName, + string ImplementationTypeName, + ServiceLifetime Lifetime, + ImmutableEquatableArray Tags); + + /// + /// Represents a Func standalone service registration entry to be generated. + /// + /// The full Func service type name (e.g. Func<string, IService>). + /// The fully-qualified return service type name. + /// The fully-qualified implementation type name. + /// The service lifetime matching the inner service. + /// The implementation constructor parameters. + /// The implementation injection members. + /// The Func input type parameters. + /// The tags inherited from the source registration. + private readonly partial record struct FuncRegistrationEntry( + string FuncServiceTypeName, + string InnerServiceTypeName, + string ImplementationTypeName, + ServiceLifetime Lifetime, + ImmutableEquatableArray? ImplementationTypeConstructorParams, + ImmutableEquatableArray ImplementationTypeInjectionMembers, + ImmutableEquatableArray InputTypes, + ImmutableEquatableArray Tags); + + /// + /// Represents a KeyValuePair service registration entry to be generated. + /// Used when consumers depend on KeyValuePair<K, V>, IDictionary<K, V>, + /// or IEnumerable<KeyValuePair<K, V>>. + /// + /// The fully-qualified key type name (e.g., string). + /// The fully-qualified value type name (e.g., global::TestNamespace.IService). + /// The key literal expression (e.g., "Key1"). + /// The service lifetime matching the keyed value service. + /// The tags inherited from the source registration. + private readonly partial record struct KvpRegistrationEntry( + string KeyTypeName, + string ValueTypeName, + string KeyExpr, + ServiceLifetime Lifetime, + ImmutableEquatableArray Tags); +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Shared/CodeGenHelpers.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Shared/CodeGenHelpers.cs new file mode 100644 index 0000000..1303711 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Shared/CodeGenHelpers.cs @@ -0,0 +1,186 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Builds a constructor invocation expression with an optional initializer. + /// + private static string BuildConstructorInvocation(string implTypeName, string args, string initializerPart) => + $"new {implTypeName}({args}){initializerPart}"; + + /// + /// Builds an argument list from parameter entries. + /// Null values are omitted so the target member uses its optional parameter default. + /// + private static string BuildArgumentListFromEntries(IEnumerable<(string Name, string? Value)> entries) + { + // Check if any parameter uses default value + bool hasDefaultValue = entries.Any(e => e.Value is null); + + if(!hasDefaultValue) + { + // All values present - use positional arguments + return string.Join(", ", entries.Select(e => e.Value!)); + } + + // Some values are null - use named arguments for non-null values only + var namedArgs = entries.Where(e => e.Value is not null).Select(e => $"{e.Name}: {e.Value}"); + return string.Join(", ", namedArgs); + } + + /// + /// Converts a TypeData to typeof() syntax for open generic types. + /// For example: TypeData with Name="global::Namespace.GenericTest<T>" becomes "typeof(global::Namespace.GenericTest<>)" + /// + private static string ConvertToTypeOf(TypeData typeData) + { + if(typeData is not GenericTypeData genericTypeData || genericTypeData.GenericArity == 0) + { + return $"typeof({typeData.Name})"; + } + + // Build the open generic typeof + return $"typeof({genericTypeData.NameWithoutGeneric}{GetGenericString(genericTypeData.GenericArity)})"; + } + + /// + /// Cached generic arity strings to avoid repeated allocations. + /// + private static readonly string[] s_genericArityStrings = + [ + "<>", + "<,>", + "<,,>", + "<,,,>", + "<,,,,>", + "<,,,,,>", + "<,,,,,,>", + "<,,,,,,,>", + "<,,,,,,,,>" + ]; + + private static string GetGenericString(int arity) => + arity <= 9 ? s_genericArityStrings[arity - 1] : '<' + new string(',', arity - 1) + '>'; + + private static string GetServiceResolutionMethod(string? serviceKey, bool isOptional) => + (serviceKey is not null, isOptional) switch + { + (true, true) => "GetKeyedService", + (true, false) => "GetRequiredKeyedService", + (false, true) => "GetService", + (false, false) => "GetRequiredService", + }; + + /// + /// Builds the generic type arguments string for a generic factory method. + /// Uses the to map placeholder types in the service type template + /// to the actual types from the closed service type. + /// + /// The factory method data containing the generic type mapping. + /// The closed service type to extract type arguments from. + /// The generic type arguments string (e.g., "Entity, Dto"), or null if not a generic factory. + /// + /// Given: + /// - ServiceTypeTemplate: IRequestHandler<Task<int>> + /// - PlaceholderToTypeParamMap: { "int" -> 0 } + /// - ClosedServiceType: IRequestHandler<Task<Entity>> + /// Returns: "Entity" + /// + private static string? BuildGenericFactoryTypeArgs(FactoryMethodData factory, TypeData closedServiceType) + { + var mapping = factory.GenericTypeMapping; + if(mapping is null || factory.TypeParameterCount == 0) + { + return null; + } + + var template = mapping.ServiceTypeTemplate; + var placeholderMap = mapping.PlaceholderToTypeParamMap; + + // Build a map from placeholder types to actual types by comparing template with closed service type + var placeholderToActualType = new Dictionary(StringComparer.Ordinal); + ExtractPlaceholderMappings(template, closedServiceType, placeholderToActualType); + + // Build type arguments array in the order of factory method's type parameters + var typeArgs = new string[factory.TypeParameterCount]; + foreach(var kvp in placeholderMap) + { + var placeholderTypeName = kvp.Key; + var typeParamIndex = kvp.Value; + + if(typeParamIndex < 0 || typeParamIndex >= typeArgs.Length) + { + continue; + } + + if(placeholderToActualType.TryGetValue(placeholderTypeName, out var actualTypeName)) + { + typeArgs[typeParamIndex] = actualTypeName; + } + } + + // Validate all type arguments are filled + foreach(var arg in typeArgs) + { + if(string.IsNullOrEmpty(arg)) + { + return null; // Missing type argument, cannot generate generic call + } + } + + return string.Join(", ", typeArgs); + } + + /// + /// Recursively extracts mappings from placeholder types in the template to actual types in the closed type. + /// + private static void ExtractPlaceholderMappings( + TypeData template, + TypeData closed, + Dictionary placeholderToActualType) + { + if(template is not GenericTypeData genericTemplate || closed is not GenericTypeData genericClosed) + { + return; + } + + // If base type names don't match, can't extract mappings + if(genericTemplate.NameWithoutGeneric != genericClosed.NameWithoutGeneric) + { + return; + } + + var templateParams = genericTemplate.TypeParameters; + var closedParams = genericClosed.TypeParameters; + + if(templateParams is null || closedParams is null) + { + return; + } + + if(templateParams.Length != closedParams.Length) + { + return; + } + + for(int i = 0; i < templateParams.Length; i++) + { + var templateParamType = templateParams[i].Type; + var closedParamType = closedParams[i].Type; + + // If the template param is a simple type (no nested type parameters), + // it's a placeholder that maps to the closed param type + if(templateParamType is not GenericTypeData { TypeParameters.Length: > 0 }) + { + // Map the template type name to the closed type name + // e.g., "global::System.Int32" -> "global::TestNamespace.Entity" + placeholderToActualType[templateParamType.Name] = closedParamType.Name; + } + else + { + // Nested generic type, recurse + ExtractPlaceholderMappings(templateParamType, closedParamType, placeholderToActualType); + } + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/FeatureFilterHelper.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Shared/FeatureFilterHelper.cs similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/FeatureFilterHelper.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Shared/FeatureFilterHelper.cs diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Shared/SourceWriterExtensions.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Shared/SourceWriterExtensions.cs new file mode 100644 index 0000000..72e3380 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Emit/Shared/SourceWriterExtensions.cs @@ -0,0 +1,31 @@ +using static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +internal static class SourceWriterExtensions +{ + extension(SourceWriter writer) + { + public void WriteServiceLambdaOpen(string lifetime, string serviceTypeName, string? registrationKey) + { + if(registrationKey is not null) + { + writer.WriteLine($"services.AddKeyed{lifetime}<{serviceTypeName}>({registrationKey}, ({IServiceProviderGlobalTypeName} sp, object? key) =>"); + return; + } + + writer.WriteLine($"services.Add{lifetime}<{serviceTypeName}>(({IServiceProviderGlobalTypeName} sp) =>"); + } + + public void WriteEarlyReturnIfNotNull(string fieldName) + { + writer.WriteLine($"if({fieldName} is not null) return {fieldName};"); + } + + public void WriteFieldAssignAndReturn(string fieldName, string instanceVar) + { + writer.WriteLine($"{fieldName} = {instanceVar};"); + writer.WriteLine($"return {instanceVar};"); + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateContainerOutput.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateContainerOutput.cs deleted file mode 100644 index a319f60..0000000 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateContainerOutput.cs +++ /dev/null @@ -1,3287 +0,0 @@ -using static SourceGen.Ioc.SourceGenerator.Models.Constants; - -namespace SourceGen.Ioc; - -partial class IocSourceGenerator -{ - /// - /// Generates the container source code output. - /// - private static void GenerateContainerOutput( - in SourceProductionContext ctx, - ContainerWithGroups containerWithGroups, - string assemblyName, - MsBuildProperties msbuildProps, - bool hasDIPackage) - { - if((msbuildProps.Features & IocFeatures.Container) == 0) - return; - - var filteredContainerWithGroups = FilterContainerWithGroupsForFeatures(containerWithGroups, msbuildProps.Features); - var source = GenerateContainerSource(filteredContainerWithGroups, assemblyName, msbuildProps, hasDIPackage); - var fileName = $"{containerWithGroups.Container.ClassName}.Container.g.cs"; - ctx.AddSource(fileName, source); - } - - /// - /// Generates the container source code. - /// - private static string GenerateContainerSource( - ContainerWithGroups containerWithGroups, - string assemblyName, - MsBuildProperties msbuildProps, - bool hasDIPackage) - { - var writer = new SourceWriter(); - var container = containerWithGroups.Container; - var groups = containerWithGroups.Groups; - - // Determine if IServiceProviderFactory should be generated - // Only generate if IntegrateServiceProvider is true AND the DI package is referenced - var canGenerateServiceProviderFactory = container.IntegrateServiceProvider && hasDIPackage; - - // Effective UseSwitchStatement: when there are imported modules, always use FrozenDictionary - // because combining services from multiple sources requires dictionary-based lookup - var effectiveUseSwitchStatement = container.UseSwitchStatement && container.ImportedModules.Length == 0; - - WriteContainerHeader(writer); - - // Write [assembly: MetadataUpdateHandler] attribute for hot reload cache invalidation - if(container.ImplementComponentActivator || container.ImplementComponentPropertyActivator) - { - writer.WriteLine($"[assembly: global::System.Reflection.Metadata.MetadataUpdateHandler(typeof({container.ContainerTypeName}.__HotReloadHandler))]"); - writer.WriteLine(); - } - - WriteContainerNamespaceAndClass(writer, container, groups, canGenerateServiceProviderFactory, effectiveUseSwitchStatement); - - return writer.ToString(); - } - - private static ContainerWithGroups FilterContainerWithGroupsForFeatures(ContainerWithGroups containerWithGroups, IocFeatures features) - { - if(IocFeaturesHelper.HasAllInjectionFeatures(features)) - return containerWithGroups; - - var groups = containerWithGroups.Groups; - - var byServiceTypeAndKey = new Dictionary<(string ServiceType, string? Key), ImmutableEquatableArray>(); - foreach(var kvp in groups.ByServiceTypeAndKey) - { - byServiceTypeAndKey[kvp.Key] = FilterCachedRegistrations(kvp.Value, features); - } - - var filteredByServiceTypeAndKey = byServiceTypeAndKey.ToImmutableEquatableDictionary(); - - var filteredSingletons = FilterCachedRegistrations(groups.Singletons, features); - var filteredScoped = FilterCachedRegistrations(groups.Scoped, features); - var filteredTransients = FilterCachedRegistrations(groups.Transients, features); - - var collectionRegistrations = new Dictionary>(); - foreach(var kvp in groups.CollectionRegistrations) - { - collectionRegistrations[kvp.Key] = FilterCachedRegistrations(kvp.Value, features); - } - - var filteredGroups = groups with - { - ByServiceTypeAndKey = filteredByServiceTypeAndKey, - Singletons = filteredSingletons, - Scoped = filteredScoped, - Transients = filteredTransients, - CollectionRegistrations = collectionRegistrations.ToImmutableEquatableDictionary(), - ReversedSingletonsForDisposal = FilterCachedRegistrations(groups.ReversedSingletonsForDisposal, features), - ReversedScopedForDisposal = FilterCachedRegistrations(groups.ReversedScopedForDisposal, features), - EagerSingletons = filteredSingletons.Where(static c => c.IsEager).ToImmutableEquatableArray(), - EagerScoped = filteredScoped.Where(static c => c.IsEager).ToImmutableEquatableArray(), - LazyEntries = CollectContainerLazyEntries(filteredSingletons, filteredScoped, filteredTransients, filteredByServiceTypeAndKey), - FuncEntries = CollectContainerFuncEntries(filteredSingletons, filteredScoped, filteredTransients, filteredByServiceTypeAndKey), - KvpEntries = CollectContainerKvpEntries(filteredSingletons, filteredScoped, filteredTransients, filteredByServiceTypeAndKey) - }; - - return containerWithGroups with { Groups = filteredGroups }; - } - - private static ImmutableEquatableArray FilterCachedRegistrations( - ImmutableEquatableArray registrations, - IocFeatures features) - { - if(registrations.Length == 0) - return registrations; - - List? filteredRegistrations = null; - for(var i = 0; i < registrations.Length; i++) - { - var registration = registrations[i]; - var filteredRegistration = FilterRegistrationForFeatures(registration.Registration, features); - if(ReferenceEquals(filteredRegistration, registration.Registration)) - { - if(filteredRegistrations is not null) - filteredRegistrations.Add(registration); - - continue; - } - - filteredRegistrations ??= new List(registrations.Length); - if(filteredRegistrations.Count == 0) - { - for(var j = 0; j < i; j++) - filteredRegistrations.Add(registrations[j]); - } - - filteredRegistrations.Add(registration with { Registration = filteredRegistration, IsAsyncInit = HasAsyncInitMembers(filteredRegistration) }); - } - - return filteredRegistrations is null ? registrations : filteredRegistrations.ToImmutableEquatableArray(); - } - - /// - /// Writes the auto-generated header and using directives. - /// - private static void WriteContainerHeader(SourceWriter writer) - { - writer.WriteLine(AutoGeneratedHeader); - writer.WriteLine(NullableEnable); - writer.WriteLine("#pragma warning disable SGIOCEXP001"); - writer.WriteLine(); - writer.WriteLine("using System;"); - writer.WriteLine("using System.Collections.Frozen;"); - writer.WriteLine("using System.Collections.Generic;"); - writer.WriteLine("using System.Linq;"); - writer.WriteLine("using System.Threading;"); - writer.WriteLine("using System.Threading.Tasks;"); - writer.WriteLine("using Microsoft.Extensions.DependencyInjection;"); - writer.WriteLine("using SourceGen.Ioc;"); - writer.WriteLine(); - } - - /// - /// Writes the namespace and class declaration with all implemented interfaces. - /// - private static void WriteContainerNamespaceAndClass( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups, - bool canGenerateServiceProviderFactory, - bool effectiveUseSwitchStatement) - { - // Write namespace if not global - bool hasNamespace = !string.IsNullOrEmpty(container.ContainerNamespace); - if(hasNamespace) - { - writer.WriteLine($"namespace {container.ContainerNamespace};"); - writer.WriteLine(); - } - - // Get interface list - var interfaces = GetContainerInterfaces(container, canGenerateServiceProviderFactory); - - // Write class declaration - writer.WriteLine($"partial class {container.ClassName} : {string.Join(", ", interfaces)}"); - writer.WriteLine("{"); - writer.Indentation++; - - // Write fields - WriteContainerFields(writer, container); - - // Write constructors - WriteContainerConstructors(writer, container, groups, effectiveUseSwitchStatement); - - // Write service resolver methods - WriteServiceResolverMethods(writer, container.ThreadSafeStrategy, groups); - - // Write partial accessor implementations (user-declared partial methods/properties) - WritePartialAccessorImplementations(writer, container, groups); - - // Write IServiceProvider implementation - WriteIServiceProviderImplementation(writer, container, groups, effectiveUseSwitchStatement); - - // Write IKeyedServiceProvider implementation - WriteIKeyedServiceProviderImplementation(writer, container, groups, effectiveUseSwitchStatement); - - // Write ISupportRequiredService implementation - WriteISupportRequiredServiceImplementation(writer, container); - - // Write ServiceProvider extension methods (generic overloads) - WriteServiceProviderExtensions(writer, container, effectiveUseSwitchStatement); - - // Write IServiceProviderIsService implementation - WriteIServiceProviderIsServiceImplementation(writer, container, groups, effectiveUseSwitchStatement); - - // Write IServiceScopeFactory implementation - WriteIServiceScopeFactoryImplementation(writer, container); - - // Write IIocContainer implementation - WriteIIocContainerImplementation(writer, container, groups, effectiveUseSwitchStatement); - - // Write IServiceProviderFactory implementation (if DI package is available) - if(canGenerateServiceProviderFactory) - { - WriteIServiceProviderFactoryImplementation(writer, container); - } - - // Write Disposal implementation - WriteDisposalImplementation(writer, container, groups); - - // Write IControllerActivator implementation (if container declares the interface) - if(container.ImplementControllerActivator) - { - WriteIControllerActivatorImplementation(writer, container); - } - - // Write IComponentActivator implementation (if container declares the interface) - if(container.ImplementComponentActivator) - { - WriteIComponentActivatorImplementation(writer, container); - } - - // Write IComponentPropertyActivator implementation (if container declares the interface) - if(container.ImplementComponentPropertyActivator) - { - WriteIComponentPropertyActivatorImplementation(writer, container, effectiveUseSwitchStatement); - } - - // Write hot reload handler (if either component activator interface is implemented) - if(container.ImplementComponentActivator || container.ImplementComponentPropertyActivator) - { - WriteHotReloadHandler(writer, container); - } - - writer.Indentation--; - writer.WriteLine("}"); - } - - private static readonly string[] _FixedContainerInterfaces = [ - "IServiceProvider", - "IKeyedServiceProvider", - "IServiceProviderIsService", - "IServiceProviderIsKeyedService", - "ISupportRequiredService", - "IServiceScopeFactory", - "IServiceScope", - "IDisposable", - "IAsyncDisposable" - ]; - /// - /// Builds the list of interfaces the container should implement. - /// - private static IEnumerable GetContainerInterfaces(ContainerModel container, bool canGenerateServiceProviderFactory) - { - yield return $"IIocContainer<{container.ContainerTypeName}>"; - - foreach(var i in _FixedContainerInterfaces) - yield return i; - - if(canGenerateServiceProviderFactory) - yield return "IServiceProviderFactory"; - } - - /// - /// Writes container fields (fallback provider, service storage, locks, etc.). - /// - private static void WriteContainerFields( - SourceWriter writer, - ContainerModel container) - { - // Fallback provider field (only if IntegrateServiceProvider is enabled) - if(container.IntegrateServiceProvider) - { - writer.WriteLine("private readonly IServiceProvider? _fallbackProvider;"); - } - - // Scope tracking - writer.WriteLine("private readonly bool _isRootScope = true;"); - writer.WriteLine("private int _disposed;"); - writer.WriteLine(); - - // Imported module fields - foreach(var module in container.ImportedModules) - { - var fieldName = GetModuleFieldName(module.Name); - writer.WriteLine($"private readonly {module.Name} {fieldName};"); - } - - if(container.ImportedModules.Length > 0) - { - writer.WriteLine(); - } - } - - /// - /// Writes a service instance field and synchronization field based on ThreadSafeStrategy. - /// For eager services, fields are non-nullable and no synchronization is needed. - /// - private static void WriteServiceInstanceField( - SourceWriter writer, - ThreadSafeStrategy strategy, - ServiceRegistrationModel reg, - string fieldName, - bool hasDecorators, - bool isEager) - { - // When there are decorators, field type is ServiceType (interface), otherwise ImplementationType - var typeName = hasDecorators ? reg.ServiceType.Name : reg.ImplementationType.Name; - - if(isEager) - { - // Eager services use non-nullable fields (initialized in constructor) - // Use null! to suppress CS8618 warning - field will be initialized in constructor - writer.WriteLine($"private {typeName} {fieldName} = null!;"); - // No synchronization field needed for eager services - return; - } - - // Lazy services use nullable fields - writer.WriteLine($"private {typeName}? {fieldName};"); - - // Generate synchronization field based on strategy - switch(strategy) - { - case ThreadSafeStrategy.None: - // No synchronization field needed - break; - - case ThreadSafeStrategy.Lock: - writer.WriteLine($"private readonly Lock {fieldName}Lock = new();"); - break; - - case ThreadSafeStrategy.SemaphoreSlim: - writer.WriteLine($"private readonly SemaphoreSlim {fieldName}Semaphore = new(1, 1);"); - break; - - case ThreadSafeStrategy.SpinLock: - // SpinLock must NOT be readonly because Enter/Exit mutate it - writer.WriteLine($"private SpinLock {fieldName}SpinLock = new(false);"); - break; - - case ThreadSafeStrategy.CompareExchange: - // No synchronization field needed - uses Interlocked.CompareExchange - break; - } - } - - /// - /// Writes container constructors. - /// - private static void WriteContainerConstructors( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups, - bool effectiveUseSwitchStatement) - { - writer.WriteLine("#region Constructors"); - writer.WriteLine(); - - // Default constructor - writer.WriteLine("/// "); - writer.WriteLine("/// Creates a new standalone container without external service provider fallback."); - writer.WriteLine("/// "); - - // Need fallback provider constructor if IntegrateServiceProvider is enabled - var needsFallbackProvider = container.IntegrateServiceProvider; - - if(needsFallbackProvider) - { - writer.WriteLine($"public {container.ClassName}() : this((IServiceProvider?)null) {{ }}"); - } - else - { - // Standalone mode - no fallback provider - writer.WriteLine($"public {container.ClassName}()"); - writer.WriteLine("{"); - writer.Indentation++; - WriteConstructorBody(writer, container, groups, hasParameter: false, effectiveUseSwitchStatement); - writer.Indentation--; - writer.WriteLine("}"); - } - writer.WriteLine(); - - // Constructor with fallback provider (if enabled) - if(needsFallbackProvider) - { - writer.WriteLine("/// "); - writer.WriteLine("/// Creates a new container with optional fallback to external service provider."); - writer.WriteLine("/// "); - writer.WriteLine("/// Optional external service provider for unknown dependencies."); - writer.WriteLine($"public {container.ClassName}(IServiceProvider? fallbackProvider)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("_fallbackProvider = fallbackProvider;"); - WriteConstructorBody(writer, container, groups, hasParameter: true, effectiveUseSwitchStatement); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - } - - // Private constructor for scoped instances - writer.WriteLine($"private {container.ClassName}({container.ClassName} parent)"); - writer.WriteLine("{"); - writer.Indentation++; - if(needsFallbackProvider) - { - writer.WriteLine("_fallbackProvider = parent._fallbackProvider;"); - } - writer.WriteLine("_isRootScope = false;"); - - // Copy singleton references from parent (already filtered for non-open-generics) - // Skip instance registrations as they don't have fields - foreach(var cached in groups.Singletons) - { - // Instance registrations don't have fields, skip them - if(cached.Registration.Instance is not null) - continue; - - writer.WriteLine($"{cached.FieldName} = parent.{cached.FieldName};"); - } - - // Create scopes for imported modules (so their scoped services are properly isolated) - foreach(var module in container.ImportedModules) - { - var fieldName = GetModuleFieldName(module.Name); - writer.WriteLine($"{fieldName} = ({module.Name})parent.{fieldName}.CreateScope().ServiceProvider;"); - } - - // Initialize eager scoped services by calling their Get methods - var eagerScoped = groups.EagerScoped; - if(eagerScoped.Length > 0) - { - writer.WriteLine(); - writer.WriteLine("// Initialize eager scoped services"); - foreach(var cached in eagerScoped) - { - writer.WriteLine($"{cached.FieldName} = {cached.ResolverMethodName}();"); - } - } - - // Initialize Lazy/Func wrapper fields (each scope gets its own wrappers) - WriteContainerLazyFieldInitializations(writer, groups.LazyEntries); - WriteContainerFuncFieldInitializations(writer, groups.FuncEntries); - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes the constructor body for building the service resolver dictionary. - /// - private static void WriteConstructorBody( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups, - bool hasParameter, - bool effectiveUseSwitchStatement) - { - // Initialize imported modules - foreach(var module in container.ImportedModules) - { - var fieldName = GetModuleFieldName(module.Name); - if(container.IntegrateServiceProvider && hasParameter) - { - writer.WriteLine($"{fieldName} = new {module.Name}(fallbackProvider);"); - } - else - { - writer.WriteLine($"{fieldName} = new {module.Name}();"); - } - } - - // Initialize eager singletons by calling their Get methods - // This ensures dependencies are resolved in the correct order - var eagerSingletons = groups.EagerSingletons; - if(eagerSingletons.Length > 0) - { - writer.WriteLine(); - writer.WriteLine("// Initialize eager singletons"); - foreach(var cached in eagerSingletons) - { - writer.WriteLine($"{cached.FieldName} = {cached.ResolverMethodName}();"); - } - } - - // Initialize Lazy/Func wrapper fields - WriteContainerLazyFieldInitializations(writer, groups.LazyEntries); - WriteContainerFuncFieldInitializations(writer, groups.FuncEntries); - } - - /// - /// Writes individual service resolver methods. - /// - private static void WriteServiceResolverMethods( - SourceWriter writer, - ThreadSafeStrategy strategy, - ContainerRegistrationGroups groups) - { - writer.WriteLine("#region Service Resolution"); - writer.WriteLine(); - - var writtenMethods = new HashSet(); - - // All registrations are already filtered (no open generics) - WriteServiceResolverGroup(writer, strategy, groups.Singletons, writtenMethods, groups); - WriteServiceResolverGroup(writer, strategy, groups.Scoped, writtenMethods, groups); - WriteServiceResolverGroup(writer, strategy, groups.Transients, writtenMethods, groups); - - // Write array resolver methods for IEnumerable, IReadOnlyCollection, IReadOnlyList, T[] support - foreach(var serviceType in groups.CollectionServiceTypes) - { - WriteArrayResolverMethod(writer, serviceType, groups); - writer.WriteLine(); - } - - // Write KVP resolver methods for keyed services consumed as KeyValuePair/Dictionary - WriteContainerKvpResolverMethods(writer, groups.KvpEntries); - - // Write Lazy wrapper field declarations and array resolvers - WriteContainerLazyFields(writer, groups.LazyEntries); - - // Write Func wrapper field declarations and array resolvers - WriteContainerFuncFields(writer, groups.FuncEntries); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes implementations for user-declared partial methods and partial properties - /// that serve as fast-path service accessors. - /// - private static void WritePartialAccessorImplementations( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups) - { - if(container.PartialAccessors.Length == 0) - return; - - writer.WriteLine("#region Partial Accessor Implementations"); - writer.WriteLine(); - - foreach(var accessor in container.PartialAccessors) - { - var isTaskReturn = accessor.Kind == PartialAccessorKind.Method - && TryExtractTaskInnerType(accessor.ReturnTypeName, out _); - - var resolveExpression = ResolvePartialAccessorExpression(accessor, container, groups); - - switch(accessor.Kind) - { - case PartialAccessorKind.Method: - // Async partial methods (returning Task) require the 'async' modifier - if(isTaskReturn) - { - writer.WriteLine($"public partial async {accessor.ReturnTypeName} {accessor.Name}() => {resolveExpression};"); - } - else - { - writer.WriteLine($"public partial {accessor.ReturnTypeName}{(accessor.IsNullable ? "?" : "")} {accessor.Name}() => {resolveExpression};"); - } - break; - - case PartialAccessorKind.Property: - writer.WriteLine($"public partial {accessor.ReturnTypeName}{(accessor.IsNullable ? "?" : "")} {accessor.Name} {{ get => {resolveExpression}; }}"); - break; - } - - writer.WriteLine(); - } - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Resolves the expression to use for a partial accessor implementation. - /// Looks up the registration by return type and optional key, with fallback to IServiceProvider. - /// For Task<T> return types, routes through the async resolver. - /// - private static string ResolvePartialAccessorExpression( - PartialAccessorData accessor, - ContainerModel container, - ContainerRegistrationGroups groups) - { - var serviceType = accessor.ReturnTypeName; - var key = accessor.Key; - - // Handle Task return types β€” route through the async resolver. - if(TryExtractTaskInnerType(serviceType, out var innerTypeName)) - { - if(groups.ByServiceTypeAndKey.TryGetValue((innerTypeName, key), out var taskRegistrations)) - { - var cached = taskRegistrations[^1]; // Last registration wins - - if(cached.IsAsyncInit) - { - // Async-init: await the shared async resolver and let the async method wrap the cast. - var asyncMethodName = GetAsyncResolverMethodName(cached.ResolverMethodName); - return $"await {asyncMethodName}()"; - } - else - { - // Sync-only service wrapped as Task: use Task.FromResult with cast. - return $"global::System.Threading.Tasks.Task.FromResult(({innerTypeName}){cached.ResolverMethodName}())"; - } - } - - // Fallback: delegate to IServiceProvider if available - if(container.IntegrateServiceProvider) - { - if(key is not null) - return $"({serviceType})GetRequiredKeyedService(typeof({serviceType}), {key})"; - return $"({serviceType})GetRequiredService(typeof({serviceType}))"; - } - - return $"""throw new global::System.InvalidOperationException("Service '{innerTypeName}' is not registered.")"""; - } - - // Try to find direct resolver in this container - if(groups.ByServiceTypeAndKey.TryGetValue((serviceType, key), out var registrations)) - { - var cached = registrations[^1]; // Last registration wins - - if(cached.Registration.Instance is not null) - { - // Instance registration: use the instance expression directly - return cached.Registration.Instance; - } - - return $"{cached.ResolverMethodName}()"; - } - - // Fallback to GetService/GetRequiredService (only when IntegrateServiceProvider is enabled) - if(container.IntegrateServiceProvider) - { - if(key is not null) - { - return accessor.IsNullable - ? $"GetKeyedService(typeof({serviceType}), {key}) as {serviceType}" - : $"({serviceType})GetRequiredKeyedService(typeof({serviceType}), {key})"; - } - - return accessor.IsNullable - ? $"GetService(typeof({serviceType})) as {serviceType}" - : $"({serviceType})GetRequiredService(typeof({serviceType}))"; - } - - // No resolver found and no fallback: throw (analyzer should have caught this) - return accessor.IsNullable - ? "default" - : $"""throw new global::System.InvalidOperationException("Service '{serviceType}' is not registered.")"""; - } - - /// - /// Writes resolver methods for a group of registrations, ensuring unique method names. - /// - private static void WriteServiceResolverGroup( - SourceWriter writer, - ThreadSafeStrategy strategy, - IEnumerable registrations, - HashSet writtenMethods, - ContainerRegistrationGroups groups) - { - foreach(var cached in registrations) - { - if(!writtenMethods.Add(cached.ResolverMethodName)) - continue; - WriteServiceResolverMethod(writer, strategy, cached, groups); - writer.WriteLine(); - } - } - - /// - /// Writes an array resolver method for IEnumerable<T>, IReadOnlyCollection<T>, IReadOnlyList<T>, T[] resolution. - /// - private static void WriteArrayResolverMethod( - SourceWriter writer, - string serviceType, - ContainerRegistrationGroups groups) - { - var methodName = GetArrayResolverMethodName(serviceType); - var returnType = $"{serviceType}[]"; - - // Use pre-computed collection registrations (already filtered for non-open-generics) - if(!groups.CollectionRegistrations.TryGetValue(serviceType, out var registrations)) - return; - - // Deduplicate by resolver method name (or instance expression for instance registrations) - var uniqueKeys = new HashSet(); - var resolverEntries = new List(); - - foreach(var cached in registrations) - { - // Use instance expression as key for instance registrations, otherwise use method name - var key = cached.Registration.Instance ?? cached.ResolverMethodName; - if(uniqueKeys.Add(key)) - { - resolverEntries.Add(cached); - } - } - - // Only generate collection if there are multiple unique entries - if(resolverEntries.Count < 2) - return; - - writer.WriteLine($"private {returnType} {methodName}() =>"); - writer.Indentation++; - writer.WriteLine("["); - writer.Indentation++; - - foreach(var cached in resolverEntries) - { - if(cached.Registration.Instance is not null) - { - // Instance registration: directly use the instance expression - writer.WriteLine($"{cached.Registration.Instance},"); - } - else - { - // Regular registration: call the resolver method - writer.WriteLine($"{cached.ResolverMethodName}(),"); - } - } - - writer.Indentation--; - writer.WriteLine("];"); - writer.Indentation--; - } - - /// - /// Writes a single service resolver method. - /// - private static void WriteServiceResolverMethod( - SourceWriter writer, - ThreadSafeStrategy strategy, - CachedRegistration cached, - ContainerRegistrationGroups groups) - { - var reg = cached.Registration; - var methodName = cached.ResolverMethodName; - var fieldName = cached.FieldName; - var isEager = cached.IsEager; - - // Check if factory or instance registration - bool hasFactory = reg.Factory is not null; - bool hasInstance = reg.Instance is not null; - bool hasDecorators = reg.Decorators.Length > 0; - - // Return type: if there are decorators, return the ServiceType (interface), otherwise ImplementationType - var returnType = hasDecorators ? reg.ServiceType.Name : reg.ImplementationType.Name; - - // Instance registration: no resolver method needed, will be inlined in _localResolvers - if(hasInstance) - { - return; - } - - // For async-init services: generate async resolver instead of sync resolver - if(cached.IsAsyncInit) - { - switch(reg.Lifetime) - { - case ServiceLifetime.Singleton: - case ServiceLifetime.Scoped: - WriteAsyncServiceResolverMethod(writer, strategy, methodName, returnType, fieldName, reg, hasFactory, hasDecorators, groups); - break; - - case ServiceLifetime.Transient: - WriteAsyncTransientResolverMethod(writer, methodName, returnType, reg, hasFactory, hasDecorators, groups); - break; - } - return; - } - - switch(reg.Lifetime) - { - case ServiceLifetime.Singleton: - case ServiceLifetime.Scoped: - // Write field and synchronization field above the resolver method for better readability - WriteServiceInstanceField(writer, strategy, reg, fieldName, hasDecorators, isEager); - - if(isEager) - { - // Eager services: Get method still creates instance on first call (from constructor) - // but no synchronization needed since constructor runs single-threaded - WriteEagerResolverMethod(writer, methodName, returnType, fieldName, reg, hasFactory, hasDecorators, groups); - } - else - { - // Lazy services: Write resolver method based on thread-safe strategy - WriteResolverMethodWithThreadSafety(writer, strategy, methodName, returnType, fieldName, reg, hasFactory, hasDecorators, groups); - } - break; - - case ServiceLifetime.Transient: - WriteTransientResolverMethod(writer, methodName, returnType, reg, hasFactory, hasDecorators, groups); - break; - } - } - - /// - /// Writes resolver method for eager services. - /// Eager services are initialized in the constructor, so no synchronization is needed. - /// The Get method still handles first-call initialization for dependency resolution. - /// - private static void WriteEagerResolverMethod( - SourceWriter writer, - string methodName, - string returnType, - string fieldName, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - writer.WriteLine($"private {returnType} {methodName}()"); - writer.WriteLine("{"); - writer.Indentation++; - - // For non-nullable reference types, we use object comparison instead of pattern matching - // since the field is initialized to null! and will be set during constructor - writer.WriteLine($"if({fieldName} is not null) return {fieldName};"); - writer.WriteLine(); - - // No synchronization needed - constructor runs single-threaded - var variableType = hasDecorators ? reg.ServiceType.Name : null; - WriteInstanceCreationWithInjection(writer, "instance", reg, hasFactory, variableType, groups); - - if(hasDecorators) - { - writer.WriteLine(); - WriteDecoratorApplication(writer, "instance", reg, groups); - } - - writer.WriteLine(); - writer.WriteLine($"{fieldName} = instance;"); - writer.WriteLine("return instance;"); - - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes resolver method based on the specified thread-safety strategy. - /// - private static void WriteResolverMethodWithThreadSafety( - SourceWriter writer, - ThreadSafeStrategy strategy, - string methodName, - string returnType, - string fieldName, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - writer.WriteLine($"private {returnType} {methodName}()"); - writer.WriteLine("{"); - writer.Indentation++; - - // Early return check (common to all strategies) - writer.WriteLine($"if({fieldName} is not null) return {fieldName};"); - writer.WriteLine(); - - switch(strategy) - { - case ThreadSafeStrategy.None: - WriteResolverBodyNone(writer, fieldName, reg, hasFactory, hasDecorators, groups); - break; - - case ThreadSafeStrategy.Lock: - WriteResolverBodyLock(writer, fieldName, reg, hasFactory, hasDecorators, groups); - break; - - case ThreadSafeStrategy.SemaphoreSlim: - WriteResolverBodySemaphoreSlim(writer, fieldName, reg, hasFactory, hasDecorators, groups); - break; - - case ThreadSafeStrategy.SpinLock: - WriteResolverBodySpinLock(writer, fieldName, reg, hasFactory, hasDecorators, groups); - break; - - case ThreadSafeStrategy.CompareExchange: - WriteResolverBodyCompareExchange(writer, fieldName, reg, hasFactory, hasDecorators, groups); - break; - } - - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes resolver body for ThreadSafeStrategy.None (no synchronization). - /// - private static void WriteResolverBodyNone( - SourceWriter writer, - string fieldName, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - var variableType = hasDecorators ? reg.ServiceType.Name : null; - WriteInstanceCreationWithInjection(writer, "instance", reg, hasFactory, variableType, groups); - - if(hasDecorators) - { - writer.WriteLine(); - WriteDecoratorApplication(writer, "instance", reg, groups); - } - - writer.WriteLine(); - writer.WriteLine($"{fieldName} = instance;"); - writer.WriteLine("return instance;"); - } - - /// - /// Writes resolver body for ThreadSafeStrategy.Lock (using lock statement). - /// - private static void WriteResolverBodyLock( - SourceWriter writer, - string fieldName, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - writer.WriteLine($"lock({fieldName}Lock)"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine($"if({fieldName} is not null) return {fieldName};"); - writer.WriteLine(); - - var variableType = hasDecorators ? reg.ServiceType.Name : null; - WriteInstanceCreationWithInjection(writer, "instance", reg, hasFactory, variableType, groups); - - if(hasDecorators) - { - writer.WriteLine(); - WriteDecoratorApplication(writer, "instance", reg, groups); - } - - writer.WriteLine(); - writer.WriteLine($"{fieldName} = instance;"); - writer.WriteLine("return instance;"); - - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes resolver body for ThreadSafeStrategy.SemaphoreSlim. - /// - private static void WriteResolverBodySemaphoreSlim( - SourceWriter writer, - string fieldName, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - writer.WriteLine($"{fieldName}Semaphore.Wait();"); - writer.WriteLine("try"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine($"if({fieldName} is not null) return {fieldName};"); - writer.WriteLine(); - - var variableType = hasDecorators ? reg.ServiceType.Name : null; - WriteInstanceCreationWithInjection(writer, "instance", reg, hasFactory, variableType, groups); - - if(hasDecorators) - { - writer.WriteLine(); - WriteDecoratorApplication(writer, "instance", reg, groups); - } - - writer.WriteLine(); - writer.WriteLine($"{fieldName} = instance;"); - writer.WriteLine("return instance;"); - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("finally"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine($"{fieldName}Semaphore.Release();"); - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes resolver body for ThreadSafeStrategy.SpinLock. - /// - private static void WriteResolverBodySpinLock( - SourceWriter writer, - string fieldName, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - writer.WriteLine("bool lockTaken = false;"); - writer.WriteLine("try"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine($"{fieldName}SpinLock.Enter(ref lockTaken);"); - writer.WriteLine($"if({fieldName} is not null) return {fieldName};"); - writer.WriteLine(); - - var variableType = hasDecorators ? reg.ServiceType.Name : null; - WriteInstanceCreationWithInjection(writer, "instance", reg, hasFactory, variableType, groups); - - if(hasDecorators) - { - writer.WriteLine(); - WriteDecoratorApplication(writer, "instance", reg, groups); - } - - writer.WriteLine(); - writer.WriteLine($"{fieldName} = instance;"); - writer.WriteLine("return instance;"); - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("finally"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("if(lockTaken) {fieldName}SpinLock.Exit();".Replace("{fieldName}", fieldName)); - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes resolver body for ThreadSafeStrategy.CompareExchange (lock-free CAS pattern). - /// Uses Interlocked.CompareExchange to atomically set the field. If another thread wins the race, - /// the losing instance is disposed via DisposeService and the winning instance is returned. - /// - private static void WriteResolverBodyCompareExchange( - SourceWriter writer, - string fieldName, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - var variableType = hasDecorators ? reg.ServiceType.Name : null; - WriteInstanceCreationWithInjection(writer, "instance", reg, hasFactory, variableType, groups); - - if(hasDecorators) - { - writer.WriteLine(); - WriteDecoratorApplication(writer, "instance", reg, groups); - } - - writer.WriteLine(); - writer.WriteLine($"var existing = Interlocked.CompareExchange(ref {fieldName}, instance, null);"); - writer.WriteLine("if(existing is not null)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("DisposeService(instance);"); - writer.WriteLine("return existing;"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("return instance;"); - } - - /// - /// Writes transient resolver method (creates new instance each time). - /// - private static void WriteTransientResolverMethod( - SourceWriter writer, - string methodName, - string returnType, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - writer.WriteLine($"private {returnType} {methodName}()"); - writer.WriteLine("{"); - writer.Indentation++; - - // Simple case: no injection members and no decorators - if(reg.InjectionMembers.Length == 0 && !hasDecorators) - { - var instanceCreation = BuildInstanceCreationInline(reg, hasFactory, groups); - writer.WriteLine($"return {instanceCreation};"); - writer.Indentation--; - writer.WriteLine("}"); - return; - } - - // Complex case: has injection members or decorators - var variableType = hasDecorators ? reg.ServiceType.Name : null; - WriteInstanceCreationWithInjection(writer, "instance", reg, hasFactory, variableType, groups); - - if(hasDecorators) - { - writer.WriteLine(); - WriteDecoratorApplication(writer, "instance", reg, groups); - } - - writer.WriteLine("return instance;"); - writer.Indentation--; - writer.WriteLine("}"); - } - - // ────────────────────────────────────────────────────────────────────────────── - // Async-init service resolver generation - // Async-init services have at least one InjectionMemberType.AsyncMethod member. - // They use Task caching and async resolver methods. - // ────────────────────────────────────────────────────────────────────────────── - - /// - /// Returns the async routing resolver method name by appending "Async" to the sync method name. - /// - private static string GetAsyncResolverMethodName(string syncMethodName) - => syncMethodName + "Async"; - - /// - /// Returns the async creation method name (e.g. "CreateFooBarAsync" from "GetFooBar"). - /// - private static string GetAsyncCreateMethodName(string syncMethodName) - { - if(syncMethodName.Length > 3 && syncMethodName.StartsWith("Get", StringComparison.Ordinal)) - return "Create" + syncMethodName[3..] + "Async"; - return syncMethodName + "_CreateAsync"; - } - - /// - /// Returns the effective thread-safety strategy for a registration. - /// Async-init services auto-upgrade async-incompatible strategies to . - /// - private static ThreadSafeStrategy GetEffectiveThreadSafeStrategy( - ThreadSafeStrategy strategy, - bool isAsyncInit) - { - if(!isAsyncInit) - return strategy; - - return strategy is ThreadSafeStrategy.None ? ThreadSafeStrategy.None : ThreadSafeStrategy.SemaphoreSlim; - } - - /// - /// Writes the field declaration for an async-init service's cached Task<T>. - /// The caller must pass the effective async-init strategy. - /// - private static void WriteAsyncServiceInstanceField( - SourceWriter writer, - ThreadSafeStrategy strategy, - string fieldName, - string taskFieldTypeName) - { - writer.WriteLine($"private {taskFieldTypeName}? {fieldName};"); - - // Only SemaphoreSlim is async-compatible; others fall back to unsynchronized access. - if(strategy == ThreadSafeStrategy.SemaphoreSlim) - { - writer.WriteLine($"private readonly global::System.Threading.SemaphoreSlim {fieldName}Semaphore = new(1, 1);"); - } - } - - /// - /// Writes an async routing resolver + async creation method for a singleton/scoped async-init service. - /// - private static void WriteAsyncServiceResolverMethod( - SourceWriter writer, - ThreadSafeStrategy strategy, - string syncMethodName, - string returnType, - string fieldName, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - var asyncMethodName = GetAsyncResolverMethodName(syncMethodName); - var createMethodName = GetAsyncCreateMethodName(syncMethodName); - var taskReturnType = $"global::System.Threading.Tasks.Task<{returnType}>"; - var effectiveStrategy = GetEffectiveThreadSafeStrategy(strategy, true); - - // Write the Task? instance field (+ semaphore if SemaphoreSlim) - WriteAsyncServiceInstanceField(writer, effectiveStrategy, fieldName, taskReturnType); - writer.WriteLine(); - - // ── Routing resolver method ── - writer.WriteLine($"private async {taskReturnType} {asyncMethodName}()"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine($"if({fieldName} is not null)"); - writer.Indentation++; - writer.WriteLine($"return await {fieldName};"); - writer.Indentation--; - writer.WriteLine(); - - if(effectiveStrategy == ThreadSafeStrategy.SemaphoreSlim) - { - WriteAsyncResolverBodySemaphoreSlim(writer, fieldName, createMethodName); - } - else - { - WriteAsyncResolverBodyNone(writer, fieldName, createMethodName); - } - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // ── Creation method ── - writer.WriteLine($"private async {taskReturnType} {createMethodName}()"); - writer.WriteLine("{"); - writer.Indentation++; - - WriteAsyncInstanceCreationBody(writer, reg, hasFactory, hasDecorators, groups); - - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes an async creation method for a transient async-init service. - /// Each call produces a new Task (no caching). - /// - private static void WriteAsyncTransientResolverMethod( - SourceWriter writer, - string syncMethodName, - string returnType, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - var createMethodName = GetAsyncCreateMethodName(syncMethodName); - var taskReturnType = $"global::System.Threading.Tasks.Task<{returnType}>"; - - writer.WriteLine($"private async {taskReturnType} {createMethodName}()"); - writer.WriteLine("{"); - writer.Indentation++; - - WriteAsyncInstanceCreationBody(writer, reg, hasFactory, hasDecorators, groups); - - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes the instance creation body for an async-init service: - /// constructor, sync injection (properties + sync methods), await async methods, optional decorators. - /// - private static void WriteAsyncInstanceCreationBody( - SourceWriter writer, - ServiceRegistrationModel reg, - bool hasFactory, - bool hasDecorators, - ContainerRegistrationGroups groups) - { - var (properties, syncMethods, asyncMethods) = CategorizeInjectionMembersAsync(reg.InjectionMembers); - var args = BuildConstructorArgumentsString(reg, groups); - - // When decorators are present AND there is method injection (sync or async), we cannot - // type the instance as the service interface because [IocInject] methods may be on the - // concrete implementation only. - // - // Two-variable pattern: - // var baseInstance = new Impl(args) { Props... }; - // baseInstance.SyncMethod(...); - // await baseInstance.AsyncMethod(...); - // ServiceType instance = baseInstance; - // instance = new Decorator(instance); // decorator chain - // - // Single-variable pattern (no decorators, or decorators + pure property injection): - // var instance = new Impl(args) { Props... }; - // await instance.AsyncInit(...); - bool hasMethods = syncMethods is { Count: > 0 } || asyncMethods is { Count: > 0 }; - bool needsTwoVarPattern = hasDecorators && hasMethods; - - // ── Create the instance ── - string injectionVar = needsTwoVarPattern ? "baseInstance" : "instance"; - string varTypeDecl = (hasDecorators && !needsTwoVarPattern) ? reg.ServiceType.Name : "var"; - - if(hasFactory) - { - var factoryCall = BuildFactoryCallForContainer(reg.Factory!, reg, groups); - writer.WriteLine($"{varTypeDecl} {injectionVar} = ({reg.ImplementationType.Name}){factoryCall};"); - } - else - { - WriteConstructorWithPropertyInitializers(writer, injectionVar, varTypeDecl, reg.ImplementationType.Name, args, properties, groups); - } - - // ── Sync method injection ── - if(syncMethods is { Count: > 0 }) - { - foreach(var method in syncMethods) - { - var methodArgs = method.Parameters is { Length: > 0 } - ? string.Join(", ", method.Parameters.Select(p => BuildParameterForContainer(p, reg, groups))) - : ""; - writer.WriteLine($"{injectionVar}.{method.Name}({methodArgs});"); - } - } - - // ── Awaited async method injection ── - if(asyncMethods is { Count: > 0 }) - { - foreach(var method in asyncMethods) - { - var methodArgs = method.Parameters is { Length: > 0 } - ? string.Join(", ", method.Parameters.Select(p => BuildParameterForContainer(p, reg, groups))) - : ""; - writer.WriteLine($"await {injectionVar}.{method.Name}({methodArgs});"); - } - } - - // ── Apply decorators after all injection ── - if(hasDecorators) - { - writer.WriteLine(); - if(needsTwoVarPattern) - { - // Convert the concrete implementation variable to the service type - // so the decorator chain can reassign the variable. - writer.WriteLine($"{reg.ServiceType.Name} instance = {injectionVar};"); - } - WriteDecoratorApplication(writer, "instance", reg, groups); - } - - writer.WriteLine("return instance;"); - } - - /// - /// Writes the async routing resolver body for (no synchronization). - /// - private static void WriteAsyncResolverBodyNone( - SourceWriter writer, - string fieldName, - string createMethodName) - { - writer.WriteLine($"{fieldName} = {createMethodName}();"); - writer.WriteLine($"return await {fieldName};"); - } - - /// - /// Writes the async routing resolver body for . - /// Uses WaitAsync() for async-compatible locking. - /// - private static void WriteAsyncResolverBodySemaphoreSlim( - SourceWriter writer, - string fieldName, - string createMethodName) - { - writer.WriteLine($"await {fieldName}Semaphore.WaitAsync();"); - writer.WriteLine("try"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine($"if({fieldName} is null)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine($"{fieldName} = {createMethodName}();"); - writer.Indentation--; - writer.WriteLine("}"); - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("finally"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine($"{fieldName}Semaphore.Release();"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine($"return await {fieldName};"); - } - - /// - /// Categorizes injection members into properties/fields, sync methods, and async methods. - /// - private static (List? Properties, List? SyncMethods, List? AsyncMethods) CategorizeInjectionMembersAsync( - ImmutableEquatableArray injectionMembers) - { - List? properties = null; - List? syncMethods = null; - List? asyncMethods = null; - - foreach(var member in injectionMembers) - { - switch(member.MemberType) - { - case InjectionMemberType.Property: - case InjectionMemberType.Field: - properties ??= []; - properties.Add(member); - break; - case InjectionMemberType.Method: - syncMethods ??= []; - syncMethods.Add(member); - break; - case InjectionMemberType.AsyncMethod: - asyncMethods ??= []; - asyncMethods.Add(member); - break; - } - } - - return (properties, syncMethods, asyncMethods); - } - - /// - /// Tries to extract the inner type name from a - /// global::System.Threading.Tasks.Task<T> type name string. - /// Returns and sets if matched. - /// - private static bool TryExtractTaskInnerType(string typeName, out string innerTypeName) - { - const string TaskPrefix = "global::System.Threading.Tasks.Task<"; - if(typeName.StartsWith(TaskPrefix, StringComparison.Ordinal) - && typeName.EndsWith(">", StringComparison.Ordinal)) - { - innerTypeName = typeName[TaskPrefix.Length..^1]; - return true; - } - innerTypeName = string.Empty; - return false; - } - - /// - /// Writes instance creation with property/method injection. - /// - /// Optional explicit type for the variable. When null, 'var' is used. - private static void WriteInstanceCreationWithInjection( - SourceWriter writer, - string varName, - ServiceRegistrationModel reg, - bool hasFactory, - string? variableType, - ContainerRegistrationGroups groups) - { - var typeDeclaration = variableType ?? "var"; - - if(hasFactory) - { - var factoryCall = BuildFactoryCallForContainer(reg.Factory!, reg, groups); - writer.WriteLine($"{typeDeclaration} {varName} = ({reg.ImplementationType.Name}){factoryCall};"); - return; - } - - WriteConstructorWithInjection(writer, varName, typeDeclaration, reg, groups); - } - - /// - /// Writes constructor invocation with property/field and method injection. - /// - private static void WriteConstructorWithInjection( - SourceWriter writer, - string varName, - string typeDeclaration, - ServiceRegistrationModel reg, - ContainerRegistrationGroups groups) - { - var (properties, methods) = CategorizeInjectionMembers(reg.InjectionMembers); - var args = BuildConstructorArgumentsString(reg, groups); - - WriteConstructorWithPropertyInitializers(writer, varName, typeDeclaration, reg.ImplementationType.Name, args, properties, groups); - WriteMethodInjectionCalls(writer, varName, methods, reg, groups); - } - - /// - /// Categorizes injection members into properties/fields and methods. - /// - private static (List? Properties, List? Methods) CategorizeInjectionMembers( - ImmutableEquatableArray injectionMembers) - { - List? properties = null; - List? methods = null; - - foreach(var member in injectionMembers) - { - if(member.MemberType is InjectionMemberType.Property or InjectionMemberType.Field) - { - properties ??= []; - properties.Add(member); - } - else if(member.MemberType == InjectionMemberType.Method) - { - methods ??= []; - methods.Add(member); - } - } - - return (properties, methods); - } - - /// - /// Writes constructor invocation with optional property initializers. - /// - private static void WriteConstructorWithPropertyInitializers( - SourceWriter writer, - string varName, - string typeDeclaration, - string typeName, - string args, - List? properties, - ContainerRegistrationGroups groups) - { - if(properties is not { Count: > 0 }) - { - writer.WriteLine($"{typeDeclaration} {varName} = new {typeName}({args});"); - return; - } - - writer.WriteLine($"{typeDeclaration} {varName} = new {typeName}({args})"); - writer.WriteLine("{"); - writer.Indentation++; - - foreach(var prop in properties) - { - var resolveCall = BuildServiceResolutionCallForContainer(prop.Type!, prop.Key, prop.IsNullable, groups); - writer.WriteLine($"{prop.Name} = {resolveCall},"); - } - - writer.Indentation--; - writer.WriteLine("};"); - } - - /// - /// Writes method injection calls. - /// - private static void WriteMethodInjectionCalls( - SourceWriter writer, - string varName, - List? methods, - ServiceRegistrationModel reg, - ContainerRegistrationGroups groups) - { - if(methods is null) - return; - - foreach(var method in methods) - { - var methodArgs = method.Parameters is { Length: > 0 } - ? string.Join(", ", method.Parameters.Select(p => BuildParameterForContainer(p, reg, groups))) - : ""; - writer.WriteLine($"{varName}.{method.Name}({methodArgs});"); - } - } - - /// - /// Writes decorator application code. - /// Decorators are applied in reverse order (from innermost to outermost), - /// matching the behavior of Register mode. - /// - private static void WriteDecoratorApplication( - SourceWriter writer, - string varName, - ServiceRegistrationModel reg, - ContainerRegistrationGroups groups) - { - // Decorators array is in order from outermost to innermost, - // we iterate in reverse order for building the chain from inner to outer - var decorators = reg.Decorators; - for(int i = decorators.Length - 1; i >= 0; i--) - { - var decorator = decorators[i]; - var hasInjectionMembers = decorator.InjectionMembers?.Length > 0; - var argsString = string.Join(", ", GetDecoratorArguments(varName, decorator, groups)); - - if(hasInjectionMembers) - { - // When decorator has injection members, use a temporary variable with concrete type - // to allow accessing the decorator's properties/methods before assigning to interface variable - var decoratorVarName = $"decorator{decorators.Length - 1 - i}"; - WriteDecoratorCreationWithInjection(writer, decoratorVarName, decorator, argsString, reg, groups); - - // Assign to the interface variable - writer.WriteLine($"{varName} = {decoratorVarName};"); - } - else - { - // No injection members, directly assign to the interface variable - writer.WriteLine($"{varName} = new {decorator.Name}({argsString});"); - } - } - } - - /// - /// Yields decorator constructor arguments without allocating a List. - /// First parameter is always the inner instance (the decorated service). - /// - private static IEnumerable GetDecoratorArguments(string innerInstance, TypeData decorator, ContainerRegistrationGroups groups) - { - yield return innerInstance; - - if(decorator.ConstructorParameters?.Length > 1) - { - // Skip the first parameter (it's the inner/decorated service) - foreach(var param in decorator.ConstructorParameters.Skip(1)) - { - yield return BuildServiceResolutionCallForContainer(param.Type, param.ServiceKey, param.IsNullable, groups); - } - } - } - - /// - /// Writes decorator creation with object initializer for property/field injection, - /// and method calls for method injection. - /// - private static void WriteDecoratorCreationWithInjection( - SourceWriter writer, - string varName, - TypeData decorator, - string argsString, - ServiceRegistrationModel reg, - ContainerRegistrationGroups groups) - { - var injectionMembers = decorator.InjectionMembers; - if(injectionMembers is null or { Length: 0 }) - { - writer.WriteLine($"var {varName} = new {decorator.Name}({argsString});"); - return; - } - - List? propertyAssignments = null; - List? methodInvocations = null; - - foreach(var member in injectionMembers) - { - switch(member.MemberType) - { - case InjectionMemberType.Property: - case InjectionMemberType.Field: - if(member.Type is not null) - { - var resolveCall = BuildServiceResolutionCallForContainer(member.Type, member.Key, member.IsNullable, groups); - propertyAssignments ??= []; - propertyAssignments.Add($"{member.Name} = {resolveCall},"); - } - break; - - case InjectionMemberType.Method: - var methodArgs = member.Parameters is { Length: > 0 } - ? string.Join(", ", member.Parameters.Select(p => BuildParameterForInjectionMethod(p, reg, groups))) - : ""; - methodInvocations ??= []; - methodInvocations.Add($"{varName}.{member.Name}({methodArgs});"); - break; - } - } - - WriteDecoratorConstructorWithProperties(writer, varName, decorator.Name, argsString, propertyAssignments); - - if(methodInvocations is not null) - { - foreach(var invocation in methodInvocations) - { - writer.WriteLine(invocation); - } - } - } - - /// - /// Writes decorator constructor with optional property initializers. - /// - private static void WriteDecoratorConstructorWithProperties( - SourceWriter writer, - string varName, - string decoratorName, - string argsString, - List? propertyAssignments) - { - if(propertyAssignments is not { Count: > 0 }) - { - writer.WriteLine($"var {varName} = new {decoratorName}({argsString});"); - return; - } - - writer.WriteLine($"var {varName} = new {decoratorName}({argsString})"); - writer.WriteLine("{"); - writer.Indentation++; - - foreach(var assignment in propertyAssignments) - { - writer.WriteLine(assignment); - } - - writer.Indentation--; - writer.WriteLine("};"); - } - - /// - /// Builds instance creation inline (for return statements). - /// - private static string BuildInstanceCreationInline( - ServiceRegistrationModel reg, - bool hasFactory, - ContainerRegistrationGroups groups) - { - if(hasFactory) - { - var factoryCall = BuildFactoryCallForContainer(reg.Factory!, reg, groups); - return $"({reg.ImplementationType.Name}){factoryCall}"; - } - - var args = BuildConstructorArgumentsString(reg, groups); - return $"new {reg.ImplementationType.Name}({args})"; - } - - /// - /// Builds constructor arguments as a string. - /// - private static string BuildConstructorArgumentsString(ServiceRegistrationModel reg, ContainerRegistrationGroups groups) - { - var parameters = reg.ImplementationType.ConstructorParameters; - if(parameters is null or { Length: 0 }) - return ""; - - return string.Join(", ", parameters.Select(p => BuildParameterForContainer(p, reg, groups))); - } - - /// - /// Builds a single parameter for injection method (handles IServiceProvider and service resolution). - /// - private static string BuildParameterForInjectionMethod(ParameterData param, ServiceRegistrationModel reg, ContainerRegistrationGroups groups) - { - if(param.Type.Name is IServiceProviderTypeName or IServiceProviderGlobalTypeName) - return "this"; - - return BuildServiceResolutionCallForContainer(param.Type, param.ServiceKey, param.IsOptional, groups); - } - - /// - /// Builds a single parameter for container (constructor, method injection, or factory). - /// Handles [ServiceKey], [FromKeyedServices], IServiceProvider, and regular service resolution. - /// - private static string BuildParameterForContainer(ParameterData param, ServiceRegistrationModel reg, ContainerRegistrationGroups groups) - { - if(param.HasServiceKeyAttribute) - return reg.Key ?? "null"; - - if(param.Type.Name is IServiceProviderTypeName or IServiceProviderGlobalTypeName) - return "this"; - - return BuildServiceResolutionCallForContainer(param.Type, param.ServiceKey, param.IsOptional, groups); - } - - /// - /// Builds a factory method call string for container. - /// - private static string BuildFactoryCallForContainer(FactoryMethodData factory, ServiceRegistrationModel reg, ContainerRegistrationGroups groups) - { - var args = GetFactoryArguments(factory, reg, groups); - var genericTypeArgs = BuildGenericFactoryTypeArgs(factory, reg.ServiceType); - var factoryCallPath = genericTypeArgs is not null ? $"{factory.Path}<{genericTypeArgs}>" : factory.Path; - - return $"{factoryCallPath}({string.Join(", ", args)})"; - } - - /// - /// Yields factory method arguments without allocating a List. - /// - private static IEnumerable GetFactoryArguments(FactoryMethodData factory, ServiceRegistrationModel reg, ContainerRegistrationGroups groups) - { - if(factory.HasServiceProvider) - yield return "this"; - - if(factory.HasKey && reg.Key is not null) - yield return reg.Key; - - foreach(var param in factory.AdditionalParameters) - yield return BuildParameterForContainer(param, reg, groups); - } - - private static string BuildServiceProviderFallbackExpression( - string typeName, - string? key, - bool isOptional) - { - if(key is not null) - return isOptional - ? $"GetKeyedService(typeof({typeName}), {key}) as {typeName}" - : $"({typeName})GetRequiredKeyedService(typeof({typeName}), {key})"; - return isOptional - ? $"GetService(typeof({typeName})) as {typeName}" - : $"({typeName})GetRequiredService(typeof({typeName}))"; - } - - /// - /// Builds a service resolution call for container (direct call or GetService/GetRequiredService). - /// When the dependency is registered in the same container, calls the resolver method directly. - /// - private static string BuildServiceResolutionCallForContainer( - TypeData type, - string? key, - bool isOptional, - ContainerRegistrationGroups groups) - { - // Collection types - use collection resolver method if available - if(type is CollectionWrapperTypeData collectionType) - { - var elementTypeName = collectionType.ElementType.Name; - - // Keyed collection services - fallback to GetKeyedServices (no direct call support yet) - if(key is not null) - { - return $"GetKeyedServices<{elementTypeName}>({key})"; - } - - // Check if element type is KeyValuePair β€” use KVP resolver if available - if(collectionType.ElementType is KeyValuePairTypeData kvpElement) - { - var kvpKeyType = kvpElement.KeyType.Name; - var kvpValueType = kvpElement.ValueType.Name; - if(HasKvpRegistrations(kvpKeyType, kvpValueType, groups)) - { - // IEnumerable, IReadOnlyCollection, ICollection β†’ Dictionary resolver - // IReadOnlyList, IList, T[] β†’ Array resolver (consistent with _localResolvers) - var isArrayType = collectionType.WrapperKind is WrapperKind.ReadOnlyList or WrapperKind.List or WrapperKind.Array; - var methodName = isArrayType - ? GetKvpArrayResolverMethodName(kvpKeyType, kvpValueType) - : GetKvpDictionaryResolverMethodName(kvpKeyType, kvpValueType); - return $"{methodName}()"; - } - } - - // Check if we have a collection resolver for this element type - if(groups.CollectionRegistrations.ContainsKey(elementTypeName)) - { - var methodName = GetArrayResolverMethodName(elementTypeName); - return $"{methodName}()"; - } - - return $"GetServices<{elementTypeName}>()"; - } - - // Wrapper types - Lazy, Func, KeyValuePair, Task - if(type is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData or TaskTypeData) - { - return BuildWrapperExpressionForContainer(type, key, isOptional, groups); - } - - // Try to find direct resolver in this container - if(groups.ByServiceTypeAndKey.TryGetValue((type.Name, key), out var registrations)) - { - var cached = registrations[^1]; // Last wins - // Async-init services: the sync method was not generated; use the async method instead. - // Callers that depend on an async-init service should be taking Task, not T directly. - // The analyzer (SGIOC027/029) normally prevents this, but fall back gracefully. - if(cached.IsAsyncInit) - { - if(cached.Registration.Lifetime == ServiceLifetime.Transient) - return $"{GetAsyncCreateMethodName(cached.ResolverMethodName)}()"; - return $"{GetAsyncResolverMethodName(cached.ResolverMethodName)}()"; - } - return $"{cached.ResolverMethodName}()"; - } - - // Fallback to GetService/GetRequiredService for dependencies not in this container - return BuildServiceProviderFallbackExpression(type.Name, key, isOptional); - } - - /// - /// Builds a wrapper expression for container resolution (Lazy, Func, KeyValuePair, Dictionary). - /// Recursively handles nested wrapper types. - /// - private static string BuildWrapperExpressionForContainer( - TypeData type, - string? key, - bool isOptional, - ContainerRegistrationGroups groups, - bool useResolverMethods = true) - { - switch(type) - { - case LazyTypeData lazy: - { - var innerType = lazy.InstanceType; - // Direct Lazy where T is not a wrapper β€” call wrapper resolver if available (only at top level) - if(innerType is not WrapperTypeData && useResolverMethods) - { - if(groups.ByServiceTypeAndKey.TryGetValue((innerType.Name, key), out var innerRegs)) - { - var lastReg = innerRegs[^1]; - var safeInnerType = GetSafeIdentifier(innerType.Name); - var safeImplType = GetSafeIdentifier(lastReg.Registration.ImplementationType.Name); - return $"_lazy_{safeInnerType}_{safeImplType}"; - } - // Fallback: inner type not in this container β€” build inline via IServiceProvider - var lazyFallbackExpr = BuildServiceProviderFallbackExpression(innerType.Name, key, isOptional); - return $"new global::System.Lazy<{innerType.Name}>(() => {lazyFallbackExpr}, global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)"; - } - // Nested wrapper or inside nested context β€” inline construction - var lazyInnerExpr = BuildInnerResolutionForContainer(innerType, key, isOptional, groups); - return $"new global::System.Lazy<{innerType.Name}>(() => {lazyInnerExpr}, global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)"; - } - - case FuncTypeData func: - { - var innerType = func.ReturnType; - - if(func.HasInputParameters) - { - if(groups.ByServiceTypeAndKey.TryGetValue((innerType.Name, key), out var innerRegs)) - { - var targetRegistration = innerRegs[^1].Registration; - return BuildContainerMultiParamFuncExpression(func, targetRegistration, groups); - } - - // Fallback: inner return type not in this container β€” resolve the full Func<...> type - // directly from IServiceProvider. Do NOT call BuildServiceResolutionCallForContainer - // here as that would route FuncTypeData back to BuildWrapperExpressionForContainer, - // causing infinite recursion. - return BuildServiceProviderFallbackExpression(type.Name, key, isOptional); - } - - // Direct Func where T is not a wrapper β€” call wrapper resolver if available (only at top level) - if(innerType is not WrapperTypeData && useResolverMethods) - { - if(groups.ByServiceTypeAndKey.TryGetValue((innerType.Name, key), out var innerRegs)) - { - var lastReg = innerRegs[^1]; - var safeInnerType = GetSafeIdentifier(innerType.Name); - var safeImplType = GetSafeIdentifier(lastReg.Registration.ImplementationType.Name); - return $"_func_{safeInnerType}_{safeImplType}"; - } - // Fallback: inner type not in this container β€” build inline via IServiceProvider - var funcFallbackExpr = BuildServiceProviderFallbackExpression(innerType.Name, key, isOptional); - return $"new global::System.Func<{innerType.Name}>(() => {funcFallbackExpr})"; - } - // Nested wrapper or inside nested context β€” inline construction - var funcInnerExpr = BuildInnerResolutionForContainer(innerType, key, isOptional, groups); - return $"new global::System.Func<{innerType.Name}>(() => {funcInnerExpr})"; - } - - case KeyValuePairTypeData kvp: - { - var keyType = kvp.KeyType; - var valueType = kvp.ValueType; - var keyExpr = key ?? "default"; - var valueExpr = BuildInnerResolutionForContainer(valueType, key, isOptional, groups); - return $"new global::System.Collections.Generic.KeyValuePair<{keyType.Name}, {valueType.Name}>({keyExpr}, {valueExpr})"; - } - - case DictionaryTypeData dict: - { - // Dictionary resolution: use KVP dictionary resolver if available, otherwise fallback. - var keyType = dict.KeyType; - var valueType = dict.ValueType; - if(key is null && HasKvpRegistrations(keyType.Name, valueType.Name, groups)) - { - var methodName = GetKvpDictionaryResolverMethodName(keyType.Name, valueType.Name); - return $"{methodName}()"; - } - var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{keyType.Name}, {valueType.Name}>"; - if(key is not null) - { - return $"GetKeyedServices<{kvpTypeName}>({key}).ToDictionary()"; - } - return $"GetServices<{kvpTypeName}>().ToDictionary()"; - } - - case TaskTypeData task: - { - // Task wrapper: route based on sync vs async-init registration. - var innerType = task.InnerType; - var innerTypeName = innerType.Name; - - if(groups.ByServiceTypeAndKey.TryGetValue((innerTypeName, key), out var innerRegs)) - { - var lastReg = innerRegs[^1]; - if(lastReg.IsAsyncInit) - { - // Async-init: project Task β†’ Task via async lambda (not ContinueWith) - // so that exceptions propagate as-awaited rather than wrapped in AggregateException. - var asyncMethodName = GetAsyncResolverMethodName(lastReg.ResolverMethodName); - return $"((global::System.Func>)(async () => ({innerTypeName})(await {asyncMethodName}())))()"; - } - else - { - // Sync-only: wrap in Task.FromResult with cast. - return $"global::System.Threading.Tasks.Task.FromResult(({innerTypeName}){lastReg.ResolverMethodName}())"; - } - } - - // Fallback to IServiceProvider - return BuildServiceProviderFallbackExpression(type.Name, key, isOptional); - } - - default: - return BuildServiceResolutionCallForContainer(type, key, isOptional, groups); - } - } - - /// - /// Builds an inner resolution expression for container β€” handles nested wrappers, nested - /// collections (via ), or direct resolution. - /// Supports nesting such as Lazy<IEnumerable<T>>. - /// - private static string BuildInnerResolutionForContainer( - TypeData innerType, - string? key, - bool isOptional, - ContainerRegistrationGroups groups) - { - if(innerType is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData or TaskTypeData) - { - // Inner wrappers always use inline construction (no resolver methods). This is a deliberate - // pragmatic choice: direct container resolution of nested wrappers (for example, - // container.GetService>>()) is extremely rare, while constructor injection is - // fully covered by inline construction. This keeps implementation simple by avoiding - // extensions to field scanning, naming conventions, and scoped-container infrastructure for - // every nested wrapper shape. Nested wrappers also typically do not require cross-consumer - // instance sharing, so each consumer owning its own inline-constructed instance is - // semantically correct. - // NOTE: Nested Task shapes such as Lazy> or IEnumerable> are not - // supported by the spec. The transform layer prevents these from reaching code generation - // by downgrading their WrapperKind to None, so they fall back to IServiceProvider. - return BuildWrapperExpressionForContainer(innerType, key, isOptional, groups, useResolverMethods: false); - } - - // Delegates to BuildServiceResolutionCallForContainer which handles: - // - Collection types (EnumerableTypeData, ReadOnlyCollectionTypeData, CollectionTypeData, ReadOnlyListTypeData, ListTypeData, ArrayTypeData) via GetServices/array resolvers - // - Direct service resolution via resolver methods or GetRequiredService fallback - return BuildServiceResolutionCallForContainer(innerType, key, isOptional, groups); - } - - /// - /// Builds an inline multi-parameter Func expression for container resolution. - /// Uses first-unused type matching from Func input args to constructor/method/property injection targets. - /// - private static string BuildContainerMultiParamFuncExpression( - FuncTypeData funcType, - ServiceRegistrationModel registration, - ContainerRegistrationGroups groups) - { - var inputTypes = funcType.InputTypes; - var inputArgNames = new string[inputTypes.Length]; - var inputArgTypeNames = new string[inputTypes.Length]; - var inputArgUsed = new bool[inputTypes.Length]; - - for(var i = 0; i < inputTypes.Length; i++) - { - inputArgNames[i] = $"arg{i}"; - inputArgTypeNames[i] = inputTypes[i].Type.Name; - inputArgUsed[i] = false; - } - - var lambdaParams = string.Join(", ", inputTypes.Select(static (t, i) => $"{t.Type.Name} arg{i}")); - - var statements = new List(); - var ctorParams = registration.ImplementationType.ConstructorParameters ?? []; - var ctorEntries = new List<(string Name, string? Value, bool NeedsConditional)>(ctorParams.Length); - - var resolvedParamIndex = 0; - foreach(var param in ctorParams) - { - var matchedArg = TryConsumeMatchingFuncInputArg(param.Type.Name, inputArgNames, inputArgTypeNames, inputArgUsed); - if(matchedArg is not null) - { - ctorEntries.Add((param.Name, matchedArg, false)); - continue; - } - - var paramVar = $"p{resolvedParamIndex}"; - var expr = BuildServiceResolutionCallForContainer(param.Type, param.ServiceKey, param.IsOptional, groups); - statements.Add($"var {paramVar} = {expr};"); - ctorEntries.Add((param.Name, paramVar, false)); - resolvedParamIndex++; - } - - var propertyInits = new List(); - var propertyIndex = 0; - foreach(var member in registration.InjectionMembers) - { - if(member.MemberType is not (InjectionMemberType.Property or InjectionMemberType.Field)) - continue; - - var memberType = member.Type; - if(memberType is null) - continue; - - var matchedArg = TryConsumeMatchingFuncInputArg(memberType.Name, inputArgNames, inputArgTypeNames, inputArgUsed); - if(matchedArg is not null) - { - propertyInits.Add($"{member.Name} = {matchedArg}"); - continue; - } - - var memberVar = $"s0_p{propertyIndex}"; - var expr = BuildServiceResolutionCallForContainer(memberType, member.Key, member.IsNullable, groups); - statements.Add($"var {memberVar} = {expr};"); - propertyInits.Add($"{member.Name} = {memberVar}"); - propertyIndex++; - } - - var ctorArgs = BuildArgumentListFromEntries([.. ctorEntries]); - var initializerPart = propertyInits.Count > 0 ? $" {{ {string.Join(", ", propertyInits)} }}" : string.Empty; - var ctorInvocation = BuildConstructorInvocation(registration.ImplementationType.Name, ctorArgs, initializerPart); - statements.Add($"var s0 = {ctorInvocation};"); - - var methodIndex = 0; - foreach(var method in registration.InjectionMembers) - { - if(method.MemberType != InjectionMemberType.Method) - continue; - - var methodParams = method.Parameters ?? []; - var methodEntries = new List<(string Name, string? Value, bool NeedsConditional)>(methodParams.Length); - foreach(var param in methodParams) - { - var matchedArg = TryConsumeMatchingFuncInputArg(param.Type.Name, inputArgNames, inputArgTypeNames, inputArgUsed); - if(matchedArg is not null) - { - methodEntries.Add((param.Name, matchedArg, false)); - continue; - } - - var paramVar = $"s0_m{methodIndex}"; - var expr = BuildServiceResolutionCallForContainer(param.Type, method.Key ?? param.ServiceKey, param.IsOptional, groups); - statements.Add($"var {paramVar} = {expr};"); - methodEntries.Add((param.Name, paramVar, false)); - methodIndex++; - } - - var methodArgs = BuildArgumentListFromEntries([.. methodEntries]); - statements.Add($"s0.{method.Name}({methodArgs});"); - } - - statements.Add("return s0;"); - - return $"new {funcType.Name}(({lambdaParams}) => {{ {string.Join(" ", statements)} }})"; - } - - /// - /// Writes IServiceProvider implementation. - /// - private static void WriteIServiceProviderImplementation( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups, - bool effectiveUseSwitchStatement) - { - writer.WriteLine("#region IServiceProvider"); - writer.WriteLine(); - - writer.WriteLine("public object? GetService(Type serviceType)"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine(); - - if(effectiveUseSwitchStatement) - { - // Built-in services (switch mode only - in dictionary mode these are in _localResolvers) - writer.WriteLine("if(serviceType == typeof(IServiceProvider)) return this;"); - writer.WriteLine("if(serviceType == typeof(IServiceScopeFactory)) return this;"); - writer.WriteLine($"if(serviceType == typeof({container.ClassName})) return this;"); - writer.WriteLine(); - - // Cascading if statements - registrations already filtered (no open generics) - foreach(var kvp in groups.ByServiceTypeAndKey) - { - if(kvp.Key.Key is not null) - continue; - - var cached = kvp.Value[^1]; // Last wins - var reg = cached.Registration; - writer.WriteLine($"if(serviceType == typeof({reg.ServiceType.Name})) return {cached.ResolverMethodName}();"); - } - - writer.WriteLine(); - } - else - { - writer.WriteLine($"if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, {KeyedServiceAnyKey}), out var resolver))"); - writer.Indentation++; - writer.WriteLine("return resolver(this);"); - writer.Indentation--; - writer.WriteLine(); - } - - // Fallback - if(container.IntegrateServiceProvider) - { - writer.WriteLine("return _fallbackProvider?.GetService(serviceType);"); - } - else - { - writer.WriteLine("return null;"); - } - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes IKeyedServiceProvider implementation. - /// - private static void WriteIKeyedServiceProviderImplementation( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups, - bool effectiveUseSwitchStatement) - { - writer.WriteLine("#region IKeyedServiceProvider"); - writer.WriteLine(); - - writer.WriteLine("public object? GetKeyedService(Type serviceType, object? serviceKey)"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine(); - - writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); - writer.WriteLine(); - - if(effectiveUseSwitchStatement) - { - // Use pre-computed hasKeyedServices flag - if(groups.HasKeyedServices) - { - // Use tuple pattern matching switch expression - writer.WriteLine("return (serviceType, key) switch"); - writer.WriteLine("{"); - writer.Indentation++; - - foreach(var kvp in groups.ByServiceTypeAndKey) - { - if(kvp.Key.Key is null) - continue; - - var cached = kvp.Value[^1]; // Last wins - writer.WriteLine($"(Type t, object k) when t == typeof({kvp.Key.ServiceType}) && Equals(k, {kvp.Key.Key}) => {cached.ResolverMethodName}(),"); - } - - // Fallback in switch default case - if(container.IntegrateServiceProvider) - { - writer.WriteLine("_ => _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null"); - } - else - { - writer.WriteLine("_ => null"); - } - - writer.Indentation--; - writer.WriteLine("};"); - } - else - { - // No keyed services, just return fallback - if(container.IntegrateServiceProvider) - { - writer.WriteLine("return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null;"); - } - else - { - writer.WriteLine("return null;"); - } - } - } - else - { - writer.WriteLine("if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver))"); - writer.Indentation++; - writer.WriteLine("return resolver(this);"); - writer.Indentation--; - writer.WriteLine(); - - // Fallback - if(container.IntegrateServiceProvider) - { - writer.WriteLine("return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null;"); - } - else - { - writer.WriteLine("return null;"); - } - } - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // GetRequiredKeyedService - writer.WriteLine("public object GetRequiredKeyedService(Type serviceType, object? serviceKey)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine("return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($\"No service for type '{serviceType}' with key '{serviceKey}' has been registered.\");"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes ISupportRequiredService implementation. - /// - private static void WriteISupportRequiredServiceImplementation(SourceWriter writer, ContainerModel container) - { - writer.WriteLine("#region ISupportRequiredService"); - writer.WriteLine(); - - writer.WriteLine("public object GetRequiredService(Type serviceType)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine("return GetService(serviceType) ?? throw new InvalidOperationException($\"No service for type '{serviceType}' has been registered.\");"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes generic service resolution extension methods matching - /// ServiceProviderServiceExtensions and ServiceProviderKeyedServiceExtensions signatures. - /// In dictionary mode, methods directly query _serviceResolvers for optimal performance. - /// In switch mode, methods delegate to non-generic counterparts. - /// - private static void WriteServiceProviderExtensions( - SourceWriter writer, - ContainerModel container, - bool effectiveUseSwitchStatement) - { - writer.WriteLine("#region ServiceProvider Extensions"); - writer.WriteLine(); - - if(effectiveUseSwitchStatement) - { - WriteServiceProviderExtensionsSwitchMode(writer); - } - else - { - WriteServiceProviderExtensionsDictionaryMode(writer); - } - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes generic extension methods for switch/if-cascade mode. - /// These delegate to the non-generic methods. - /// - private static void WriteServiceProviderExtensionsSwitchMode(SourceWriter writer) - { - // GetService - writer.WriteLine("public T? GetService() where T : class"); - writer.Indentation++; - writer.WriteLine("=> GetService(typeof(T)) as T;"); - writer.Indentation--; - writer.WriteLine(); - - // GetRequiredService - writer.WriteLine("public T GetRequiredService() where T : notnull"); - writer.Indentation++; - writer.WriteLine("=> (T)GetRequiredService(typeof(T));"); - writer.Indentation--; - writer.WriteLine(); - - // GetServices - writer.WriteLine("public System.Collections.Generic.IEnumerable GetServices()"); - writer.Indentation++; - writer.WriteLine("=> (System.Collections.Generic.IEnumerable?)GetService(typeof(System.Collections.Generic.IEnumerable)) ?? [];"); - writer.Indentation--; - writer.WriteLine(); - - // GetKeyedService - writer.WriteLine("public T? GetKeyedService(object? serviceKey) where T : class"); - writer.Indentation++; - writer.WriteLine("=> GetKeyedService(typeof(T), serviceKey) as T;"); - writer.Indentation--; - writer.WriteLine(); - - // GetRequiredKeyedService - writer.WriteLine("public T GetRequiredKeyedService(object? serviceKey) where T : notnull"); - writer.Indentation++; - writer.WriteLine("=> (T)GetRequiredKeyedService(typeof(T), serviceKey);"); - writer.Indentation--; - writer.WriteLine(); - - // GetKeyedServices - writer.WriteLine("public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey)"); - writer.Indentation++; - writer.WriteLine("=> (System.Collections.Generic.IEnumerable?)GetKeyedService(typeof(System.Collections.Generic.IEnumerable), serviceKey) ?? [];"); - writer.Indentation--; - writer.WriteLine(); - } - - /// - /// Writes generic extension methods for dictionary mode. - /// These directly query _serviceResolvers FrozenDictionary for optimal performance. - /// - private static void WriteServiceProviderExtensionsDictionaryMode(SourceWriter writer) - { - // GetService - writer.WriteLine("public T? GetService() where T : class"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine($"return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), {KeyedServiceAnyKey}), out var resolver)"); - writer.Indentation++; - writer.WriteLine("? resolver(this) as T"); - writer.WriteLine(": null;"); - writer.Indentation--; - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // GetRequiredService - writer.WriteLine("public T GetRequiredService() where T : notnull"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine($"return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), {KeyedServiceAnyKey}), out var resolver)"); - writer.Indentation++; - writer.WriteLine("? (T)resolver(this)"); - writer.WriteLine(": throw new InvalidOperationException($\"No service for type '{typeof(T)}' has been registered.\");"); - writer.Indentation--; - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // GetServices - writer.WriteLine("public System.Collections.Generic.IEnumerable GetServices()"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine($"return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), {KeyedServiceAnyKey}), out var resolver)"); - writer.Indentation++; - writer.WriteLine("? (System.Collections.Generic.IEnumerable)resolver(this)"); - writer.WriteLine(": [];"); - writer.Indentation--; - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // GetKeyedService - writer.WriteLine("public T? GetKeyedService(object? serviceKey) where T : class"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); - writer.WriteLine("return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver)"); - writer.Indentation++; - writer.WriteLine("? resolver(this) as T"); - writer.WriteLine(": null;"); - writer.Indentation--; - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // GetRequiredKeyedService - writer.WriteLine("public T GetRequiredKeyedService(object? serviceKey) where T : notnull"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); - writer.WriteLine("return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver)"); - writer.Indentation++; - writer.WriteLine("? (T)resolver(this)"); - writer.WriteLine(": throw new InvalidOperationException($\"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered.\");"); - writer.Indentation--; - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // GetKeyedServices - writer.WriteLine("public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); - writer.WriteLine("return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver)"); - writer.Indentation++; - writer.WriteLine("? (System.Collections.Generic.IEnumerable)resolver(this)"); - writer.WriteLine(": [];"); - writer.Indentation--; - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - } - - /// - /// Writes IServiceProviderIsService implementation. - /// - private static void WriteIServiceProviderIsServiceImplementation( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups, - bool effectiveUseSwitchStatement) - { - writer.WriteLine("#region IServiceProviderIsService"); - writer.WriteLine(); - - // IsService - writer.WriteLine("public bool IsService(Type serviceType)"); - writer.WriteLine("{"); - writer.Indentation++; - - if(effectiveUseSwitchStatement) - { - // Built-in services (switch mode only - in dictionary mode these are in _localResolvers) - writer.WriteLine("if(serviceType == typeof(IServiceProvider)) return true;"); - writer.WriteLine("if(serviceType == typeof(IServiceScopeFactory)) return true;"); - writer.WriteLine($"if(serviceType == typeof({container.ClassName})) return true;"); - writer.WriteLine(); - - foreach(var serviceType in groups.AllServiceTypes) - { - writer.WriteLine($"if(serviceType == typeof({serviceType})) return true;"); - } - writer.WriteLine(); - } - else - { - writer.WriteLine($"if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, {KeyedServiceAnyKey}))) return true;"); - writer.WriteLine(); - } - - if(container.IntegrateServiceProvider) - { - writer.WriteLine("return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType);"); - } - else - { - writer.WriteLine("return false;"); - } - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // IsKeyedService - writer.WriteLine("public bool IsKeyedService(Type serviceType, object? serviceKey)"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine($"var key = serviceKey ?? {KeyedServiceAnyKey};"); - - if(!effectiveUseSwitchStatement) - { - writer.WriteLine(); - writer.WriteLine("if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true;"); - } - - writer.WriteLine(); - - if(container.IntegrateServiceProvider) - { - writer.WriteLine("return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey);"); - } - else - { - writer.WriteLine("return false;"); - } - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes IServiceScopeFactory implementation. - /// - private static void WriteIServiceScopeFactoryImplementation(SourceWriter writer, ContainerModel container) - { - writer.WriteLine("#region IServiceScopeFactory"); - writer.WriteLine(); - - writer.WriteLine("public IServiceScope CreateScope()"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ThrowIfDisposed();"); - writer.WriteLine($"return new {container.ClassName}(this);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - writer.WriteLine("public AsyncServiceScope CreateAsyncScope() => new(CreateScope());"); - writer.WriteLine(); - writer.WriteLine("IServiceProvider IServiceScope.ServiceProvider => this;"); - writer.WriteLine(); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes IIocContainer implementation. - /// - private static void WriteIIocContainerImplementation( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups, - bool effectiveUseSwitchStatement) - { - writer.WriteLine("#region IIocContainer"); - writer.WriteLine(); - - if(effectiveUseSwitchStatement) - { - writer.WriteLine($"public static IReadOnlyCollection>> Resolvers => _localResolvers;"); - } - else - { - writer.WriteLine($"public static IReadOnlyCollection>> Resolvers => _serviceResolvers;"); - } - writer.WriteLine(); - - // Write _localResolvers as static field - writer.WriteLine($"private static readonly KeyValuePair>[] _localResolvers ="); - writer.WriteLine("["); - writer.Indentation++; - - // Built-in services: IServiceProvider, IServiceScopeFactory, and the container itself - writer.WriteLine($"new(new ServiceIdentifier(typeof(IServiceProvider), {KeyedServiceAnyKey}), static c => c),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(IServiceScopeFactory), {KeyedServiceAnyKey}), static c => c),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof({container.ContainerTypeName}), {KeyedServiceAnyKey}), static c => c),"); - - // Use ByServiceTypeAndKey to include all service types (already filtered for non-open-generics) - foreach(var kvp in groups.ByServiceTypeAndKey) - { - var cached = kvp.Value[^1]; // Last wins - var keyExpr = kvp.Key.Key ?? KeyedServiceAnyKey; - - // Determine resolver expression based on registration type - string resolverExpr; - if(cached.Registration.Instance is not null) - { - // Instance registration: directly return the instance - resolverExpr = $"static _ => {cached.Registration.Instance}"; - } - else if(cached.IsAsyncInit) - { - // Async-init services: expose Task from GetService. - // Singleton/Scoped: call the routing async resolver method. - // Transient: call the async creation method directly (no caching). - if(cached.Registration.Lifetime == ServiceLifetime.Transient) - { - var createMethodName = GetAsyncCreateMethodName(cached.ResolverMethodName); - resolverExpr = $"static c => c.{createMethodName}()"; - } - else - { - var asyncMethodName = GetAsyncResolverMethodName(cached.ResolverMethodName); - resolverExpr = $"static c => c.{asyncMethodName}()"; - } - } - else if(cached.IsEager) - { - // Eager services: directly access the field - resolverExpr = $"static c => c.{cached.FieldName}!"; - } - else - { - // Lazy services: call the Get method - resolverExpr = $"static c => c.{cached.ResolverMethodName}()"; - } - - writer.WriteLine($"new(new ServiceIdentifier(typeof({kvp.Key.ServiceType}), {keyExpr}), {resolverExpr}),"); - } - - // Add IEnumerable, IReadOnlyCollection, ICollection, IReadOnlyList, IList, T[] entries for collection service types - foreach(var serviceType in groups.CollectionServiceTypes) - { - var methodName = GetArrayResolverMethodName(serviceType); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IEnumerable<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyCollection<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.ICollection<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IReadOnlyList<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof(global::System.Collections.Generic.IList<{serviceType}>), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); - writer.WriteLine($"new(new ServiceIdentifier(typeof({serviceType}[]), {KeyedServiceAnyKey}), static c => c.{methodName}()),"); - } - - // Add KeyValuePair collection entries for keyed services consumed as KVP/Dictionary - WriteContainerKvpLocalResolverEntries(writer, container.ContainerTypeName, groups.KvpEntries); - - // Add Lazy/Func wrapper entries for consumers that depend on Lazy/Func - WriteContainerLazyLocalResolverEntries(writer, container.ContainerTypeName, groups.LazyEntries); - WriteContainerFuncLocalResolverEntries(writer, container.ContainerTypeName, groups.FuncEntries); - - writer.Indentation--; - writer.WriteLine("];"); - - // Write _serviceResolvers as static field (only when not using switch statement) - if(!effectiveUseSwitchStatement) - { - writer.WriteLine(); - if(container.ImportedModules.Length > 0) - { - // Combine with imported modules - wrap module resolvers to pass the correct module instance - // Use static access (module.Name is the fully qualified type name) for static abstract Resolvers - writer.WriteLine($"private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers ="); - writer.Indentation++; - - var isFirst = true; - foreach(var module in container.ImportedModules) - { - var fieldName = GetModuleFieldName(module.Name); - var source = $"{module.Name}.Resolvers.Select(static kvp => new KeyValuePair>(kvp.Key, c => kvp.Value(c.{fieldName})))"; - - if(isFirst) - { - writer.WriteLine(source); - isFirst = false; - } - else - { - writer.WriteLine($".Concat({source})"); - } - } - - writer.WriteLine(".Concat(_localResolvers)"); - writer.WriteLine(".ToFrozenDictionary();"); - writer.Indentation--; - } - else - { - writer.WriteLine($"private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary();"); - } - } - - writer.WriteLine(); - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes IServiceProviderFactory implementation. - /// - private static void WriteIServiceProviderFactoryImplementation(SourceWriter writer, ContainerModel container) - { - writer.WriteLine("#region IServiceProviderFactory"); - writer.WriteLine(); - - writer.WriteLine("/// "); - writer.WriteLine("/// Creates a new container builder (returns the same IServiceCollection)."); - writer.WriteLine("/// "); - writer.WriteLine("public IServiceCollection CreateBuilder(IServiceCollection services) => services;"); - writer.WriteLine(); - - writer.WriteLine("/// "); - writer.WriteLine("/// Creates the service provider from the built IServiceCollection."); - writer.WriteLine("/// "); - writer.WriteLine("public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder);"); - writer.WriteLine($"return new {container.ClassName}(fallbackProvider);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes Disposal implementation (IDisposable and IAsyncDisposable). - /// - private static void WriteDisposalImplementation( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups) - { - writer.WriteLine("#region Disposal"); - writer.WriteLine(); - - // IDisposable.Dispose - writer.WriteLine("public void Dispose()"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("if(Interlocked.Exchange(ref _disposed, 1) != 0) return;"); - writer.WriteLine(); - WriteDisposalBody(writer, container, groups, isAsync: false); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // IAsyncDisposable.DisposeAsync - writer.WriteLine("public async ValueTask DisposeAsync()"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("if(Interlocked.Exchange(ref _disposed, 1) != 0) return;"); - writer.WriteLine(); - WriteDisposalBody(writer, container, groups, isAsync: true); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - var hasAsyncInitServices = groups.ReversedSingletonsForDisposal.Any(static c => c.IsAsyncInit) - || groups.ReversedScopedForDisposal.Any(static c => c.IsAsyncInit); - WriteDisposalHelperMethods(writer, hasAsyncInitServices); - - writer.WriteLine("#endregion"); - } - - /// - /// Writes IControllerActivator implementation using ActivatorUtilities with ObjectFactory caching. - /// - private static void WriteIControllerActivatorImplementation(SourceWriter writer, ContainerModel container) - { - writer.WriteLine("#region IControllerActivator"); - writer.WriteLine(); - - // Static cache field - writer.WriteLine("private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _controllerFactoryCache = new();"); - writer.WriteLine("private static global::Microsoft.Extensions.DependencyInjection.ObjectFactory CreateControllerFactory("); - writer.WriteLine(" [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers("); - writer.WriteLine(" global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] global::System.Type t)"); - writer.WriteLine(" => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateFactory(t, global::System.Type.EmptyTypes);"); - writer.WriteLine(); - - // Create method - writer.WriteLine("object global::Microsoft.AspNetCore.Mvc.Controllers.IControllerActivator.Create(global::Microsoft.AspNetCore.Mvc.ControllerContext controllerContext)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(controllerContext);"); - writer.WriteLine("var controllerType = controllerContext.ActionDescriptor.ControllerTypeInfo.AsType();"); - writer.WriteLine("var instance = GetService(controllerType);"); - writer.WriteLine("if (instance is not null) return instance;"); - writer.WriteLine(); - writer.WriteLine("if (!_controllerFactoryCache.TryGetValue(controllerType, out var factory))"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("factory = CreateControllerFactory(controllerType);"); - writer.WriteLine("_controllerFactoryCache.TryAdd(controllerType, factory);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - writer.WriteLine("return factory(this, []);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // Release method - writer.WriteLine("void global::Microsoft.AspNetCore.Mvc.Controllers.IControllerActivator.Release(global::Microsoft.AspNetCore.Mvc.ControllerContext context, object controller)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(context);"); - writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(controller);"); - writer.WriteLine("if (controller is global::System.IDisposable disposable)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("disposable.Dispose();"); - writer.Indentation--; - writer.WriteLine("}"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // ReleaseAsync method - writer.WriteLine("global::System.Threading.Tasks.ValueTask global::Microsoft.AspNetCore.Mvc.Controllers.IControllerActivator.ReleaseAsync(global::Microsoft.AspNetCore.Mvc.ControllerContext context, object controller)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(context);"); - writer.WriteLine("global::System.ArgumentNullException.ThrowIfNull(controller);"); - writer.WriteLine("if (controller is global::System.IAsyncDisposable asyncDisposable)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("return asyncDisposable.DisposeAsync();"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - writer.WriteLine("((global::Microsoft.AspNetCore.Mvc.Controllers.IControllerActivator)this).Release(context, controller);"); - writer.WriteLine("return default;"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - /// - /// Writes IComponentActivator implementation using ActivatorUtilities with ObjectFactory caching. - /// - private static void WriteIComponentActivatorImplementation(SourceWriter writer, ContainerModel container) - { - writer.WriteLine("#region IComponentActivator"); - writer.WriteLine(); - - // Static cache field - writer.WriteLine("private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _componentFactoryCache = new();"); - writer.WriteLine("private static global::Microsoft.Extensions.DependencyInjection.ObjectFactory CreateComponentFactory("); - writer.WriteLine(" [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers("); - writer.WriteLine(" global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] global::System.Type t)"); - writer.WriteLine(" => global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateFactory(t, global::System.Type.EmptyTypes);"); - writer.WriteLine(); - - // CreateInstance method - writer.WriteLine("global::Microsoft.AspNetCore.Components.IComponent global::Microsoft.AspNetCore.Components.IComponentActivator.CreateInstance([global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.PublicConstructors)] global::System.Type componentType)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("if (!typeof(global::Microsoft.AspNetCore.Components.IComponent).IsAssignableFrom(componentType))"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("throw new global::System.ArgumentException($\"The type {componentType.FullName} does not implement {nameof(global::Microsoft.AspNetCore.Components.IComponent)}.\", nameof(componentType));"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - writer.WriteLine("var instance = GetService(componentType);"); - writer.WriteLine("if (instance is global::Microsoft.AspNetCore.Components.IComponent component) return component;"); - writer.WriteLine(); - writer.WriteLine("if (!_componentFactoryCache.TryGetValue(componentType, out var factory))"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("factory = CreateComponentFactory(componentType);"); - writer.WriteLine("_componentFactoryCache.TryAdd(componentType, factory);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - writer.WriteLine("return (global::Microsoft.AspNetCore.Components.IComponent)factory(this, []);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - private static void WriteIComponentPropertyActivatorImplementation(SourceWriter writer, ContainerModel container, bool effectiveUseSwitchStatement) - { - writer.WriteLine("#region IComponentPropertyActivator"); - writer.WriteLine(); - - // Static cache field - writer.WriteLine("private static readonly global::System.Collections.Concurrent.ConcurrentDictionary> _propertyActivatorCache = new();"); - writer.WriteLine(); - - // GetActivator method - explicit interface implementation - writer.WriteLine("global::System.Action global::Microsoft.AspNetCore.Components.IComponentPropertyActivator.GetActivator("); - writer.WriteLine(" [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All)] global::System.Type componentType)"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine("if (!_propertyActivatorCache.TryGetValue(componentType, out var activator))"); - writer.WriteLine("{"); - writer.Indentation++; - - // No-op optimization: only when also implementing IComponentActivator - if(container.ImplementComponentActivator) - { - if(effectiveUseSwitchStatement) - { - writer.WriteLine("activator = global::System.Array.Exists(_localResolvers, e => e.Key.Equals(new ServiceIdentifier(componentType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey)))"); - } - else - { - writer.WriteLine("activator = _serviceResolvers.ContainsKey(new ServiceIdentifier(componentType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))"); - } - writer.WriteLine(" ? static (_, _) => { }"); - writer.WriteLine(" : CreateComponentPropertyInjector(componentType);"); - } - else - { - writer.WriteLine("activator = CreateComponentPropertyInjector(componentType);"); - } - - writer.WriteLine("_propertyActivatorCache.TryAdd(componentType, activator);"); - writer.Indentation--; - writer.WriteLine("}"); - - writer.WriteLine("return activator;"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // CreateComponentPropertyInjector - reflection fallback method - WriteCreateComponentPropertyInjectorMethod(writer); - - writer.WriteLine("#endregion"); - writer.WriteLine(); - } - - private static void WriteCreateComponentPropertyInjectorMethod(SourceWriter writer) - { - writer.WriteLine("private static global::System.Action CreateComponentPropertyInjector("); - writer.WriteLine(" [global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembers(global::System.Diagnostics.CodeAnalysis.DynamicallyAccessedMemberTypes.All)] global::System.Type componentType)"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine("const global::System.Reflection.BindingFlags flags = global::System.Reflection.BindingFlags.Instance | global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic;"); - writer.WriteLine("global::System.Collections.Generic.List<(global::System.Reflection.PropertyInfo Property, object? Key)>? injectables = null;"); - writer.WriteLine(); - - // Walk inheritance chain - writer.WriteLine("for (var type = componentType; type is not null; type = type.BaseType)"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine("foreach (var property in type.GetProperties(flags))"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine("if (property.DeclaringType != type) continue;"); - writer.WriteLine("var injectAttr = property.GetCustomAttributes(true)"); - writer.WriteLine(" .FirstOrDefault(a => a.GetType().Name is \"InjectAttribute\" or \"IocInjectAttribute\");"); - writer.WriteLine("if (injectAttr is null) continue;"); - writer.WriteLine(); - writer.WriteLine("var keyProp = injectAttr.GetType().GetProperty(\"Key\");"); - writer.WriteLine("var key = keyProp?.GetValue(injectAttr);"); - writer.WriteLine("injectables ??= new();"); - writer.WriteLine("injectables.Add((property, key));"); - - writer.Indentation--; - writer.WriteLine("}"); - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // Return no-op if no injectables found - writer.WriteLine("if (injectables is null) return static (_, _) => { };"); - writer.WriteLine(); - - // Return injection delegate - writer.WriteLine("return (serviceProvider, component) =>"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("foreach (var (property, serviceKey) in injectables)"); - writer.WriteLine("{"); - writer.Indentation++; - - writer.WriteLine("object? value;"); - writer.WriteLine("if (serviceKey is not null)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("if (serviceProvider is not global::Microsoft.Extensions.DependencyInjection.IKeyedServiceProvider keyedProvider)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("throw new global::System.InvalidOperationException($\"Cannot provide a value for property '{property.Name}' on type '{componentType.FullName}'. The service provider does not implement 'IKeyedServiceProvider'.\");"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("value = keyedProvider.GetRequiredKeyedService(property.PropertyType, serviceKey);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("else"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("value = serviceProvider.GetService(property.PropertyType) ?? throw new global::System.InvalidOperationException($\"Cannot provide a value for property '{property.Name}' on type '{componentType.FullName}'. There is no registered service of type '{property.PropertyType}'.\");"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("property.SetValue(component, value);"); - - writer.Indentation--; - writer.WriteLine("}"); - writer.Indentation--; - writer.WriteLine("};"); - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - } - - /// - /// Writes the __HotReloadHandler nested class for hot reload cache invalidation. - /// - private static void WriteHotReloadHandler(SourceWriter writer, ContainerModel container) - { - writer.WriteLine("[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]"); - writer.WriteLine("internal static class __HotReloadHandler"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("public static void ClearCache(global::System.Type[]? _)"); - writer.WriteLine("{"); - writer.Indentation++; - if(container.ImplementComponentActivator) - { - writer.WriteLine("_componentFactoryCache.Clear();"); - } - if(container.ImplementComponentPropertyActivator) - { - writer.WriteLine("_propertyActivatorCache.Clear();"); - } - writer.Indentation--; - writer.WriteLine("}"); - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes the disposal body for both sync and async disposal methods. - /// - private static void WriteDisposalBody( - SourceWriter writer, - ContainerModel container, - ContainerRegistrationGroups groups, - bool isAsync) - { - // Dispose scoped services if this is a scope - writer.WriteLine("if(!_isRootScope)"); - writer.WriteLine("{"); - writer.Indentation++; - - WriteDisposalCalls(writer, groups.ReversedScopedForDisposal, container.ImportedModules, container.ThreadSafeStrategy, isAsync); - writer.WriteLine("return;"); - - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // Root scope disposal - WriteDisposalCalls(writer, groups.ReversedSingletonsForDisposal, container.ImportedModules, container.ThreadSafeStrategy, isAsync); - } - - /// - /// Writes disposal calls for services and modules. - /// - private static void WriteDisposalCalls( - SourceWriter writer, - IEnumerable services, - ImmutableEquatableArray modules, - ThreadSafeStrategy strategy, - bool isAsync) - { - var (serviceMethod, moduleMethod) = isAsync - ? ("await DisposeServiceAsync", "await {0}.DisposeAsync()") - : ("DisposeService", "{0}.Dispose()"); - - foreach(var cached in services) - { - // Skip instance registrations - they are externally managed and should not be disposed by the container - if(cached.Registration.Instance is not null) - continue; - - var effectiveStrategy = GetEffectiveThreadSafeStrategy(strategy, cached.IsAsyncInit); - - writer.WriteLine($"{serviceMethod}({cached.FieldName});"); - - // Dispose SemaphoreSlim if using SemaphoreSlim strategy (only for non-eager services) - if(effectiveStrategy == ThreadSafeStrategy.SemaphoreSlim && !cached.IsEager) - { - writer.WriteLine($"{cached.FieldName}Semaphore.Dispose();"); - } - } - - foreach(var module in modules) - { - var fieldName = GetModuleFieldName(module.Name); - writer.WriteLine(string.Format(moduleMethod, fieldName) + ";"); - } - } - - /// - /// Writes the static helper methods for disposal. - /// - private static void WriteDisposalHelperMethods(SourceWriter writer, bool hasAsyncInitServices) - { - // Helper method to throw ObjectDisposedException if disposed - writer.WriteLine("private void ThrowIfDisposed()"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("ObjectDisposedException.ThrowIf(_disposed != 0, GetType());"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // Helper method for async disposal - writer.WriteLine("private static async ValueTask DisposeServiceAsync(object? service)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync();"); - writer.WriteLine("else if(service is IDisposable disposable) disposable.Dispose();"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - // Helper method for sync disposal - writer.WriteLine("private static void DisposeService(object? service)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("if(service is IDisposable disposable) disposable.Dispose();"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - if(hasAsyncInitServices) - { - // Overload for async-init services stored as Task? - writer.WriteLine("private static async ValueTask DisposeServiceAsync(Task? task)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("if(task is { IsCompletedSuccessfully: true })"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("try"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("await DisposeServiceAsync(await task);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("catch(Exception ex)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.Indentation--; - writer.WriteLine("}"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - - writer.WriteLine("private static void DisposeService(Task? task)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("if(task is { IsCompletedSuccessfully: true })"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("try"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult());"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("catch(Exception ex)"); - writer.WriteLine("{"); - writer.Indentation++; - writer.WriteLine("global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex);"); - writer.Indentation--; - writer.WriteLine("}"); - writer.Indentation--; - writer.WriteLine("}"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine(); - } - } - - /// - /// Gets the array resolver method name for IEnumerable<T>, IReadOnlyCollection<T>, IReadOnlyList<T>, T[] resolution. - /// - private static string GetArrayResolverMethodName(string serviceType) - { - var baseName = GetSafeIdentifier(serviceType); - return $"GetAll{baseName}Array"; - } - - /// - /// Gets the field name for an imported module. - /// - private static string GetModuleFieldName(string moduleName) - { - var baseName = GetSafeIdentifier(moduleName); - return $"_{char.ToLowerInvariant(baseName[0])}{baseName[1..]}"; - } -} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateRegisterOutput.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateRegisterOutput.cs deleted file mode 100644 index 19b97f6..0000000 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GenerateRegisterOutput.cs +++ /dev/null @@ -1,1910 +0,0 @@ -ο»Ώusing static SourceGen.Ioc.SourceGenerator.Models.Constants; - -namespace SourceGen.Ioc; - -partial class IocSourceGenerator -{ - private static void GenerateRegisterOutput( - in SourceProductionContext ctx, - ImmutableEquatableArray registrations, - string rootNamespace, - string assemblyName, - MsBuildProperties msbuildProps) - { - var features = msbuildProps.Features; - if((features & IocFeatures.Register) == 0 || registrations.Length == 0) - return; - - // Generate source - var source = GenerateExtensionMethodSource(registrations, rootNamespace, assemblyName, msbuildProps.CustomIocName, features); - ctx.AddSource($"{assemblyName}.ServiceRegistration.g.cs", source); - } - - private static string GenerateExtensionMethodSource( - ImmutableEquatableArray registrations, - string rootNamespace, - string assemblyName, - string? customIocName, - IocFeatures features) - { - // Use custom IoC name if provided, otherwise use assembly name - var methodBaseName = !string.IsNullOrWhiteSpace(customIocName) - ? GetSafeIdentifier(customIocName!) - : GetSafeIdentifier(assemblyName); - - var tagKeyCache = new Dictionary, string>(); - foreach(var regWithTags in registrations) - { - var tags = regWithTags.Tags; - if(tagKeyCache.ContainsKey(tags)) - { - continue; - } - - tagKeyCache[tags] = tags.Length > 0 - ? string.Join(",", tags.OrderBy(static t => t, StringComparer.Ordinal)) - : string.Empty; - } - - // Group registrations by their tag conditions for generating conditional blocks - // Key: sorted comma-separated tags (empty string for no tags) - // Value: list of registrations - var tagGroups = new Dictionary>(StringComparer.Ordinal); - var shouldFilterInjection = !IocFeaturesHelper.HasAllInjectionFeatures(features); - - foreach(var regWithTags in registrations) - { - var registration = shouldFilterInjection - ? FilterRegistrationForFeatures(regWithTags.Registration, features) - : regWithTags.Registration; - var tags = regWithTags.Tags; - var tagKey = tagKeyCache[tags]; - - if(!tagGroups.TryGetValue(tagKey, out var group)) - { - group = new List(); - tagGroups[tagKey] = group; - } - - group.Add(registration); - } - - // Build the final output - var mainWriter = new SourceWriter(); - - mainWriter.WriteLine(AutoGeneratedHeader); - mainWriter.WriteLine(NullableEnable); - mainWriter.WriteLine(); - mainWriter.WriteLine("using Microsoft.Extensions.DependencyInjection;"); - mainWriter.WriteLine("using System.Collections.Generic;"); - mainWriter.WriteLine("using System.Linq;"); - mainWriter.WriteLine(); - - mainWriter.WriteLine($"namespace {GetSafeNamespace(rootNamespace)}"); - mainWriter.WriteLine("{"); - mainWriter.Indentation++; - - mainWriter.WriteLine("/// "); - mainWriter.WriteLine($"/// Extension methods for registering services from {assemblyName}."); - mainWriter.WriteLine("/// "); - mainWriter.WriteLine($"public static class {methodBaseName}ServiceCollectionExtensions"); - mainWriter.WriteLine("{"); - mainWriter.Indentation++; - - // Generate single method with tags parameter - mainWriter.WriteLine("/// "); - mainWriter.WriteLine("/// Registers services. Services with tags are only registered when matching tags are passed."); - mainWriter.WriteLine("/// "); - mainWriter.WriteLine("/// The service collection."); - mainWriter.WriteLine("/// Optional tags to filter which services to register."); - mainWriter.WriteLine($"public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection Add{methodBaseName}(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags)"); - mainWriter.WriteLine("{"); - mainWriter.Indentation++; - - // Collect standalone Lazy/Func registrations from consumer dependencies - var lazyEntries = CollectLazyEntries(registrations); - var funcEntries = CollectFuncEntries(registrations); - - // Group Lazy entries by tag key to emit them within the correct conditional blocks - var lazyByTagKey = new Dictionary>(StringComparer.Ordinal); - foreach(var entry in lazyEntries) - { - var tagKey = tagKeyCache[entry.Tags]; - - if(!lazyByTagKey.TryGetValue(tagKey, out var list)) - { - list = []; - lazyByTagKey[tagKey] = list; - } - list.Add(entry); - } - - // Group Func entries by tag key to emit them within the correct conditional blocks - var funcByTagKey = new Dictionary>(StringComparer.Ordinal); - foreach(var entry in funcEntries) - { - var tagKey = tagKeyCache[entry.Tags]; - - if(!funcByTagKey.TryGetValue(tagKey, out var list)) - { - list = []; - funcByTagKey[tagKey] = list; - } - list.Add(entry); - } - - // Collect KVP registrations and group by tag key - var kvpEntries = CollectKeyValuePairEntries(registrations); - var kvpByTagKey = new Dictionary>(StringComparer.Ordinal); - foreach(var entry in kvpEntries) - { - var tagKey = tagKeyCache[entry.Tags]; - - if(!kvpByTagKey.TryGetValue(tagKey, out var list)) - { - list = []; - kvpByTagKey[tagKey] = list; - } - list.Add(entry); - } - - // Write registrations grouped by tag conditions - // Sort groups for deterministic output: no tags first, then by tag key - var sortedGroups = tagGroups - .OrderBy(static kvp => kvp.Key, StringComparer.Ordinal) - .ToList(); - - // Compute the set of service type names (and impl type names) that have async-init injection. - // Used by Task wrapper resolution to distinguish async-init vs sync-only services. - var asyncInitServiceTypeSet = new HashSet(StringComparer.Ordinal); - foreach(var group in tagGroups.Values) - { - foreach(var reg in group) - { - if(reg.InjectionMembers.Any(m => m.MemberType == InjectionMemberType.AsyncMethod)) - { - asyncInitServiceTypeSet.Add(reg.ServiceType.Name); - asyncInitServiceTypeSet.Add(reg.ImplementationType.Name); - } - } - } - HashSet? asyncInitServiceTypeNames = asyncInitServiceTypeSet.Count > 0 ? asyncInitServiceTypeSet : null; - - bool isFirstGroup = true; - foreach(var kvp in sortedGroups) - { - var groupKey = kvp.Key; - var groupRegistrations = kvp.Value; - - var tagList = string.IsNullOrEmpty(groupKey) - ? [] - : groupKey.Split(','); - - // Get matching wrapper entries for this tag group - lazyByTagKey.TryGetValue(groupKey, out var groupLazyEntries); - funcByTagKey.TryGetValue(groupKey, out var groupFuncEntries); - kvpByTagKey.TryGetValue(groupKey, out var groupKvpEntries); - - if(!isFirstGroup) - { - mainWriter.WriteLine(); - } - isFirstGroup = false; - - if(tagList.Length == 0) - { - // No tags - only register when no tags passed (mutually exclusive model) - WriteNoTagConditionalBlock(mainWriter, groupRegistrations, groupLazyEntries, groupFuncEntries, groupKvpEntries, asyncInitServiceTypeNames); - } - else - { - // Has tags - only register when tags match - WriteConditionalTagBlock(mainWriter, tagList, groupRegistrations, groupLazyEntries, groupFuncEntries, groupKvpEntries, asyncInitServiceTypeNames); - } - } - - mainWriter.WriteLine(); - mainWriter.WriteLine("return services;"); - - mainWriter.Indentation--; - mainWriter.WriteLine("}"); - - mainWriter.Indentation--; - mainWriter.WriteLine("}"); - - mainWriter.Indentation--; - mainWriter.WriteLine("}"); - - return mainWriter.ToString(); - } - - /// - /// Writes a group of registrations without any conditional wrapper. - /// - private static void WriteRegistrationGroup(SourceWriter writer, List registrations, HashSet? asyncInitServiceTypeNames = null) - { - foreach(var registration in registrations) - { - WriteRegistration(writer, registration, asyncInitServiceTypeNames); - } - } - - /// - /// Writes a conditional block for tag-based registration. - /// Services with tags are only registered when the passed tags match. - /// - private static void WriteConditionalTagBlock( - SourceWriter writer, - string[] tags, - List registrations, - List? lazyEntries = null, - List? funcEntries = null, - List? kvpEntries = null, - HashSet? asyncInitServiceTypeNames = null) - { - // Build the condition - only register when tags match - var tagConditions = tags.Select(static tag => $"tags.Contains(\"{tag}\")"); - var condition = string.Join(" || ", tagConditions); - - writer.WriteLine($"if ({condition})"); - writer.WriteLine("{"); - writer.Indentation++; - - foreach(var registration in registrations) - { - WriteRegistration(writer, registration, asyncInitServiceTypeNames); - } - - WriteLazyRegistrations(writer, lazyEntries); - WriteFuncRegistrations(writer, funcEntries); - WriteKvpRegistrations(writer, kvpEntries); - - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes a conditional block for services without tags. - /// Services without tags are only registered when no tags are passed (mutually exclusive model). - /// - private static void WriteNoTagConditionalBlock( - SourceWriter writer, - List registrations, - List? lazyEntries = null, - List? funcEntries = null, - List? kvpEntries = null, - HashSet? asyncInitServiceTypeNames = null) - { - writer.WriteLine("if (!tags.Any())"); - writer.WriteLine("{"); - writer.Indentation++; - - foreach(var registration in registrations) - { - WriteRegistration(writer, registration, asyncInitServiceTypeNames); - } - - WriteLazyRegistrations(writer, lazyEntries); - WriteFuncRegistrations(writer, funcEntries); - WriteKvpRegistrations(writer, kvpEntries); - - writer.Indentation--; - writer.WriteLine("}"); - } - - /// - /// Writes the start of a service registration lambda for keyed or non-keyed registrations. - /// - private static void WriteServiceRegistrationLambdaStart( - SourceWriter writer, - string lifetime, - string serviceTypeName, - string? registrationKey) - { - if(registrationKey is not null) - { - writer.WriteLine($"services.AddKeyed{lifetime}<{serviceTypeName}>({registrationKey}, ({IServiceProviderGlobalTypeName} sp, object? key) =>"); - return; - } - - writer.WriteLine($"services.Add{lifetime}<{serviceTypeName}>(({IServiceProviderGlobalTypeName} sp) =>"); - } - - private static void WriteRegistration(SourceWriter writer, ServiceRegistrationModel registration, HashSet? asyncInitServiceTypeNames = null) - { - var lifetime = registration.Lifetime.Name; - var serviceTypeName = registration.ServiceType.Name; - var implTypeName = registration.ImplementationType.Name; - - // Check if this registration uses Factory or Instance - // Note: Factory and Instance cannot be used with open generics - they require concrete types - bool hasFactory = registration.Factory is not null && !registration.IsOpenGeneric; - bool hasInstance = registration.Instance is not null && !registration.IsOpenGeneric; - - // Handle Factory registration first (takes precedence) - if(hasFactory) - { - WriteFactoryMethodRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); - return; - } - - // Handle Instance registration (only for Singleton) - if(hasInstance) - { - if(registration.Lifetime == ServiceLifetime.Singleton) - WriteInstanceRegistration(writer, registration); - - return; - } - - // Check if service type is different from implementation type (registering interface/base class) - bool isServiceTypeRegistration = serviceTypeName != implTypeName; - - // Check if this registration has decorators and is not the implementation type itself - bool hasDecorators = registration.Decorators.Length > 0 && isServiceTypeRegistration; - - if(hasDecorators) - { - if(registration.IsOpenGeneric) - { - // Open generic service types with decorators: do not generate factory lambdas with type parameters. - // Fall through to open generic registration using typeof() syntax. - } - else - { - WriteDecoratorRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); - return; - } - } - - // Check if this registration has injection members (properties, fields, methods with [Inject]) - bool hasInjectionMembers = registration.InjectionMembers.Length > 0; - - // Check if the constructor was selected by [Inject] attribute (requires factory for proper constructor selection) - bool hasInjectConstructor = registration.ImplementationType.HasInjectConstructor; - - // Check constructor parameters for special handling requirements: - // - [Inject] attribute on parameters: MS.DI doesn't recognize it - // - Non-IEnumerable collection types (IList, T[], etc.): MS.DI only supports IEnumerable - // - Has default value and is resolvable type: needs conditional handling (GetService + null check) - // - Nullable Lazy?/Func? (direct, non-nested): needs factory for optional resolution via GetService - // Note: [FromKeyedServices], [ServiceKey], IServiceProvider are handled by MS.DI automatically - var constructorParams = registration.ImplementationType.ConstructorParameters; - bool hasSpecialConstructorParams = constructorParams?.Any(static p => - p.HasInjectAttribute || - p.Type.NeedsWrapperResolution || - (p.IsNullable && p.Type is LazyTypeData { InstanceType: not WrapperTypeData } or FuncTypeData { ReturnType: not WrapperTypeData }) || - p.HasDefaultValue) == true; - - // Determine if factory construction is needed: - // - Injection members (properties, fields, methods) always require factory - // - Constructor with [Inject] attribute requires factory (to use correct constructor) - // - Constructor parameters with special handling (see above) - bool needsFactoryConstruction = hasInjectionMembers || hasInjectConstructor || hasSpecialConstructorParams; - - // For non-open-generic, service type registrations (interface/base class): - // Always use forwarding to implementation type to ensure single instance per scope/lifetime - // This generates: sp => sp.GetRequiredService() - // When impl is async-init, forwards Task β†’ Task instead. - if(!registration.IsOpenGeneric && isServiceTypeRegistration) - { - // Service type registration (interface/base class) forwards to implementation - WriteServiceTypeForwardingRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); - return; - } - - if(needsFactoryConstruction && !registration.IsOpenGeneric) - { - // Self registration with injection members or constructor params with special handling - generate factory method - bool hasAsyncInjectionMembers = registration.InjectionMembers.Any(m => m.MemberType == InjectionMemberType.AsyncMethod); - if(hasAsyncInjectionMembers) - { - WriteAsyncInjectionRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); - } - else - { - WriteInjectionRegistration(writer, registration, lifetime, asyncInitServiceTypeNames); - } - return; - } - - if(registration.IsOpenGeneric) - { - // Open generic registration requires typeof() syntax - var serviceTypeOf = ConvertToTypeOf(registration.ServiceType); - var implTypeOf = ConvertToTypeOf(registration.ImplementationType); - - if(registration.Key is not null) - { - writer.WriteLine($"services.AddKeyed{lifetime}({serviceTypeOf}, {registration.Key}, {implTypeOf});"); - return; - } - - writer.WriteLine($"services.Add{lifetime}({serviceTypeOf}, {implTypeOf});"); - return; - } - - if(registration.Key is not null) - { - // Keyed registration - writer.WriteLine($"services.AddKeyed{lifetime}<{serviceTypeName}, {registration.ImplementationType.Name}>({registration.Key});"); - return; - } - - // Non-keyed registration - writer.WriteLine($"services.Add{lifetime}<{serviceTypeName}, {registration.ImplementationType.Name}>();"); - } - - /// - /// Writes registration code using a factory method specified in the attribute. - /// - /// - /// Supports factory methods with different parameter combinations: - /// - /// No parameters: services.AddSingleton<IService>(Factory()); - /// IServiceProvider only: services.AddSingleton<IService>(sp => Factory(sp)); - /// object key only (keyed): services.AddKeyedSingleton<IService>("key", (sp, key) => Factory(key)); - /// Both (keyed): services.AddKeyedSingleton<IService>("key", (sp, key) => Factory(sp, key)); - /// Additional parameters: resolved from IServiceProvider using the same logic as [IocInject] methods - /// Generic factory with [IocGenericFactory]: services.AddSingleton<IHandler<Request>>(sp => Factory<Request>(sp)); - /// - /// If the factory return type differs from the service type, adds a cast. - /// - private static void WriteFactoryMethodRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) - { - var serviceTypeName = registration.ServiceType.Name; - var factory = registration.Factory!; - var factoryPath = factory.Path; - var hasServiceProvider = factory.HasServiceProvider; - var hasKey = factory.HasKey; - var returnTypeName = factory.ReturnTypeName; - var additionalParameters = factory.AdditionalParameters; - bool isKeyedRegistration = registration.Key is not null; - - // Build generic type arguments for generic factory methods - var genericTypeArgs = BuildGenericFactoryTypeArgs(factory, registration.ServiceType); - - // If factory is generic but we couldn't resolve type arguments (e.g., duplicate placeholder types), - // skip this registration entirely - if(factory.TypeParameterCount > 0 && genericTypeArgs is null) - { - return; - } - - // Check if we have additional parameters that need resolution - bool hasAdditionalParameters = additionalParameters.Length > 0; - - if(hasAdditionalParameters) - { - // Use multi-line lambda format for readability - WriteFactoryMethodRegistrationWithAdditionalParams( - writer, registration, lifetime, serviceTypeName, factoryPath, - hasServiceProvider, hasKey, returnTypeName, additionalParameters, isKeyedRegistration, genericTypeArgs, asyncInitServiceTypeNames); - return; - } - - // Simple case: no additional parameters - var factoryInvocation = BuildFactoryInvocation( - factoryPath, - genericTypeArgs, - hasServiceProvider, - hasKey, - registration.Key, - [], - returnTypeName, - serviceTypeName); - - WriteFactoryRegistrationLine(writer, lifetime, serviceTypeName, registration.Key, factoryInvocation); - } - - /// - /// Writes factory method registration with additional parameters that need to be resolved from IServiceProvider. - /// - private static void WriteFactoryMethodRegistrationWithAdditionalParams( - SourceWriter writer, - ServiceRegistrationModel registration, - string lifetime, - string serviceTypeName, - string factoryPath, - bool hasServiceProvider, - bool hasKey, - string? returnTypeName, - ImmutableEquatableArray additionalParameters, - bool isKeyedRegistration, - string? genericTypeArgs = null, - HashSet? asyncInitServiceTypeNames = null) - { - // Open the registration lambda - WriteServiceRegistrationLambdaStart(writer, lifetime, serviceTypeName, registration.Key); - - writer.WriteLine("{"); - writer.Indentation++; - - // Resolve additional parameters using the same logic as [IocInject] methods - var paramVars = new List(additionalParameters.Length); - int paramIndex = 0; - foreach(var param in additionalParameters) - { - var paramVar = $"f_p{paramIndex}"; - var varName = ResolveParamAndEmitVar(writer, param, paramVar, isKeyedRegistration, registration.Key, asyncInitServiceTypeNames); - - // Use the resolved variable name - paramVars.Add(varName); - paramIndex++; - } - - var factoryInvocation = BuildFactoryInvocation( - factoryPath, - genericTypeArgs, - hasServiceProvider, - hasKey, - registration.Key, - paramVars, - returnTypeName, - serviceTypeName); - - writer.WriteLine($"return {factoryInvocation};"); - - writer.Indentation--; - writer.WriteLine("});"); - } - - /// - /// Builds the factory invocation expression with optional generic arguments, service provider, key, and additional parameters. - /// Adds a cast when the factory return type differs from the service type. - /// - private static string BuildFactoryInvocation( - string factoryPath, - string? genericTypeArgs, - bool hasServiceProvider, - bool hasKey, - string? registrationKey, - List additionalArgs, - string? returnTypeName, - string serviceTypeName) - { - var args = new List(additionalArgs.Count + 2); - if(hasServiceProvider) - { - args.Add("sp"); - } - - if(hasKey && registrationKey is not null) - { - args.Add(registrationKey); - } - - if(additionalArgs.Count > 0) - { - args.AddRange(additionalArgs); - } - - var factoryCallPath = genericTypeArgs is not null ? $"{factoryPath}<{genericTypeArgs}>" : factoryPath; - var factoryInvocation = $"{factoryCallPath}({string.Join(", ", args)})"; - - if(returnTypeName is not null && returnTypeName != serviceTypeName) - { - factoryInvocation = $"({serviceTypeName}){factoryInvocation}"; - } - - return factoryInvocation; - } - - /// - /// Writes a factory registration line for keyed or non-keyed services. - /// - private static void WriteFactoryRegistrationLine( - SourceWriter writer, - string lifetime, - string serviceTypeName, - string? registrationKey, - string factoryInvocation) - { - if(registrationKey is not null) - { - writer.WriteLine($"services.AddKeyed{lifetime}<{serviceTypeName}>({registrationKey}, ({IServiceProviderGlobalTypeName} sp, object? key) => {factoryInvocation});"); - return; - } - - writer.WriteLine($"services.Add{lifetime}<{serviceTypeName}>(({IServiceProviderGlobalTypeName} sp) => {factoryInvocation});"); - } - - /// - /// Writes registration code using a static instance specified in the attribute. - /// Instance registrations are only valid for Singleton lifetime. - /// - /// - /// Generates code like: - /// - /// services.AddSingleton<IMySrevice>(MyService.Default); - /// // or for keyed: - /// services.AddKeyedSingleton<IMySrevice>("key", MyService.Default); - /// - /// - private static void WriteInstanceRegistration(SourceWriter writer, ServiceRegistrationModel registration) - { - var serviceTypeName = registration.ServiceType.Name; - var instance = registration.Instance!; - - // Instance registration should only be used with Singleton - // The analyzer should catch invalid lifetime usage - if(registration.Key is not null) - { - // Keyed singleton with instance - writer.WriteLine($"services.AddKeyedSingleton<{serviceTypeName}>({registration.Key}, {instance});"); - } - else - { - // Non-keyed singleton with instance - writer.WriteLine($"services.AddSingleton<{serviceTypeName}>({instance});"); - } - } - - /// - /// Writes registration code for service types (interfaces/base classes) that forward to an already-registered implementation. - /// The implementation is already registered with its own factory method, so we just resolve it from the service provider. - /// - /// - /// Generates code like: - /// - /// services.AddTransient<IMyService>(sp => sp.GetRequiredService<MyService>()); - /// - /// - private static void WriteServiceTypeForwardingRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) - { - var serviceTypeName = registration.ServiceType.Name; - var implTypeName = registration.ImplementationType.Name; - - // When the implementation type is registered as async-init (Task), - // the forwarding registration must also use Task. - // We must use "async ... => await ..." because Task is invariant in C# β€” - // Task cannot be implicitly assigned to Task even when - // MyService : IMyService. The await unwraps the result and the async lambda - // re-wraps it as Task. - bool isAsyncInit = asyncInitServiceTypeNames?.Contains(implTypeName) == true; - if(isAsyncInit) - { - var taskServiceTypeName = $"global::System.Threading.Tasks.Task<{serviceTypeName}>"; - var taskImplTypeName = $"global::System.Threading.Tasks.Task<{implTypeName}>"; - if(registration.Key is not null) - { - var requiredCall = BuildServiceCall(GetRequiredKeyedService, taskImplTypeName, "key"); - writer.WriteLine($"services.AddKeyed{lifetime}<{taskServiceTypeName}>({registration.Key}, async ({IServiceProviderGlobalTypeName} sp, object? key) => await {requiredCall});"); - } - else - { - var requiredCall = BuildServiceCall(GetRequiredService, taskImplTypeName, serviceKey: null); - writer.WriteLine($"services.Add{lifetime}<{taskServiceTypeName}>(async ({IServiceProviderGlobalTypeName} sp) => await {requiredCall});"); - } - return; - } - - if(registration.Key is not null) - { - // Keyed registration - forward to keyed implementation - var requiredCall = BuildServiceCall(GetRequiredKeyedService, implTypeName, "key"); - writer.WriteLine($"services.AddKeyed{lifetime}<{serviceTypeName}>({registration.Key}, ({IServiceProviderGlobalTypeName} sp, object? key) => {requiredCall});"); - } - else - { - // Non-keyed registration - forward to implementation - var requiredCall = BuildServiceCall(GetRequiredService, implTypeName, serviceKey: null); - writer.WriteLine($"services.Add{lifetime}<{serviceTypeName}>(({IServiceProviderGlobalTypeName} sp) => {requiredCall});"); - } - } - - /// - /// Writes registration code for services with injection members (properties, fields, methods marked with InjectAttribute). - /// - /// - /// Generates code like: - /// - /// services.AddSingleton<MyService>((IServiceProvider sp) => - /// { - /// var s0_p0 = sp.GetRequiredService<IMayServiceDependency1>(); - /// var s0_p1 = sp.GetRequiredService<IMayServiceDependency2>(); - /// var s0_p2 = sp.GetRequiredService<IMayServiceDependency3>(); - /// var s0 = new MyService(s0_p0) { Dependency = s0_p1 }; - /// s0.Initialize(s0_p2); - /// return s0; - /// }); - /// - /// For optional parameters with default values, generates conditional logic: - /// - /// var p0 = sp.GetService<IOptionalDep>(); - /// var s0 = p0 is not null ? new MyService(optDep: p0) : new MyService(); - /// - /// - private static void WriteInjectionRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) - { - var serviceTypeName = registration.ServiceType.Name; - var implTypeName = registration.ImplementationType.Name; - var injectionMembers = registration.InjectionMembers; - - // Build the factory lambda - WriteServiceRegistrationLambdaStart(writer, lifetime, serviceTypeName, registration.Key); - - writer.WriteLine("{"); - writer.Indentation++; - - bool isKeyedRegistration = registration.Key is not null; - WriteConstructInstanceWithInjection( - writer, - instanceVarName: "s0", - implTypeName: implTypeName, - constructorParams: registration.ImplementationType.ConstructorParameters, - injectionMembers: injectionMembers, - isKeyedRegistration: isKeyedRegistration, - registrationKey: registration.Key, - serviceTypeNames: null, - ctorTypeNameResolver: null, - memberTypeNameResolver: null, - decoratedPrevVar: null, - asyncInitServiceTypeNames: asyncInitServiceTypeNames); - - writer.WriteLine("return s0;"); - - writer.Indentation--; - writer.WriteLine("});"); - } - - /// - /// Writes registration code for services that have async injection methods. - /// Generates a Task<T> registration with an async local Init() function - /// that performs construction, sync injection, and awaited async injection in order. - /// - /// - /// Generates code like: - /// - /// services.AddSingleton<Task<MyService>>((IServiceProvider sp) => - /// { - /// async Task<MyService> Init() - /// { - /// var s0_p0 = sp.GetRequiredService<ILogger>(); - /// var s0_m0 = sp.GetRequiredService<IAsyncInitializer>(); - /// var s0 = new MyService { Logger = s0_p0 }; - /// await s0.InitAsync(s0_m0); - /// return s0; - /// } - /// return Init(); - /// }); - /// - /// - private static void WriteAsyncInjectionRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) - { - var serviceTypeName = registration.ServiceType.Name; - var implTypeName = registration.ImplementationType.Name; - var injectionMembers = registration.InjectionMembers; - bool isKeyedRegistration = registration.Key is not null; - - // Service type is wrapped in Task for async-init registrations - var taskServiceTypeName = $"global::System.Threading.Tasks.Task<{serviceTypeName}>"; - var taskImplTypeName = $"global::System.Threading.Tasks.Task<{implTypeName}>"; - - // Open the registration lambda for Task - WriteServiceRegistrationLambdaStart(writer, lifetime, taskServiceTypeName, registration.Key); - - writer.WriteLine("{"); - writer.Indentation++; - - // Write the async local function header - writer.WriteLine($"async {taskImplTypeName} Init()"); - writer.WriteLine("{"); - writer.Indentation++; - - // Emit construction with sync + async injection (isAsyncMode = true) - WriteConstructInstanceWithInjection( - writer, - instanceVarName: "s0", - implTypeName: implTypeName, - constructorParams: registration.ImplementationType.ConstructorParameters, - injectionMembers: injectionMembers, - isKeyedRegistration: isKeyedRegistration, - registrationKey: registration.Key, - serviceTypeNames: null, - ctorTypeNameResolver: null, - memberTypeNameResolver: null, - decoratedPrevVar: null, - asyncInitServiceTypeNames: asyncInitServiceTypeNames, - isAsyncMode: true); - - writer.WriteLine("return s0;"); - - writer.Indentation--; - writer.WriteLine("}"); - - writer.WriteLine("return Init();"); - - writer.Indentation--; - writer.WriteLine("});"); - } - - - - /// - /// Shared helper to construct an instance and apply property/field/method injection with conditional handling. - /// Supports decorator scenarios via service-parameter detection and generic type substitution. - /// - private static void WriteConstructInstanceWithInjection( - SourceWriter writer, - string instanceVarName, - string implTypeName, - ImmutableEquatableArray? constructorParams, - ImmutableEquatableArray injectionMembers, - bool isKeyedRegistration, - string? registrationKey, - HashSet? serviceTypeNames, - Func? ctorTypeNameResolver, - Func? memberTypeNameResolver, - string? decoratedPrevVar, - HashSet? asyncInitServiceTypeNames = null, - bool isAsyncMode = false) - { - var ctorParams = constructorParams ?? []; - var constructorParamEntries = new List<(string Name, string? Value, bool NeedsConditional)>(ctorParams.Length); - int paramIndex = 0; - - foreach(var param in ctorParams) - { - var resolvedTypeName = ctorTypeNameResolver is not null ? ctorTypeNameResolver(param.Type) : param.Type.Name; - if(decoratedPrevVar is not null && serviceTypeNames is not null && IsServiceTypeParameter(param.Type, resolvedTypeName, serviceTypeNames)) - { - constructorParamEntries.Add((param.Name, decoratedPrevVar, false)); - } - else - { - var varName = decoratedPrevVar is not null ? $"{instanceVarName}_p{paramIndex}" : $"p{paramIndex}"; - var resolvedVar = ResolveParamAndEmitVar(writer, param, varName, isKeyedRegistration, registrationKey, asyncInitServiceTypeNames); - constructorParamEntries.Add((param.Name, resolvedVar, false)); - paramIndex++; - } - } - - if(constructorParamEntries.Any(e => e.NeedsConditional)) - { - var conditionalParams = constructorParamEntries.Where(e => e.NeedsConditional && e.Value is not null).ToArray(); - var initializerPart = ""; - if(conditionalParams.Length == 1) - { - var condParam = conditionalParams[0]; - var withArgs = BuildArgumentListFromEntries([.. constructorParamEntries.Select(e => (e.Name, e.NeedsConditional ? e.Value : e.Value, false))]); - var withoutArgs = BuildArgumentListFromEntries([.. constructorParamEntries.Select(e => (e.Name, e.NeedsConditional ? (string?)null : e.Value, false))]); - var withInvocation = BuildConstructorInvocation(implTypeName, withArgs, initializerPart); - var withoutInvocation = BuildConstructorInvocation(implTypeName, withoutArgs, initializerPart); - writer.WriteLine($"var {instanceVarName} = {condParam.Value} is not null ? {withInvocation} : {withoutInvocation};"); - } - else - { - var conditions = conditionalParams.Select(p => $"{p.Value} is not null").ToArray(); - var allCondition = string.Join(" && ", conditions); - writer.WriteLine($"{implTypeName} {instanceVarName};"); - writer.WriteLine($"if ({allCondition})"); - writer.WriteLine("{"); - writer.Indentation++; - var allArgs = BuildArgumentListFromEntries([.. constructorParamEntries.Select(e => (e.Name, e.Value, false))]); - var allInvocation = BuildConstructorInvocation(implTypeName, allArgs, initializerPart); - writer.WriteLine($"{instanceVarName} = {allInvocation};"); - writer.Indentation--; - writer.WriteLine("}"); - writer.WriteLine("else"); - writer.WriteLine("{"); - writer.Indentation++; - var fallbackArgs = BuildArgumentListFromEntries([.. constructorParamEntries.Select(e => (e.Name, e.NeedsConditional ? (string?)null : e.Value, false))]); - var fallbackInvocation = BuildConstructorInvocation(implTypeName, fallbackArgs, initializerPart); - writer.WriteLine($"{instanceVarName} = {fallbackInvocation};"); - writer.Indentation--; - writer.WriteLine("}"); - } - } - else - { - // Unified construction behavior for both decorator and non-decorator cases - EmitConstruction( - writer, - instanceVarName, - implTypeName, - constructorParamEntries, - injectionMembers, - isKeyedRegistration, - registrationKey, - memberTypeNameResolver, - asyncInitServiceTypeNames, - isAsyncMode); - } - - } - - private static void EmitConstruction( - SourceWriter writer, - string instanceVarName, - string implTypeName, - List<(string Name, string? Value, bool NeedsConditional)> constructorParamEntries, - ImmutableEquatableArray injectionMembers, - bool isKeyedRegistration, - string? registrationKey, - Func? memberTypeNameResolver, - HashSet? asyncInitServiceTypeNames = null, - bool isAsyncMode = false) - { - // Resolve property/field injection parameters - int pfCount = injectionMembers.Count(m => m.MemberType is InjectionMemberType.Property or InjectionMemberType.Field); - var preProps = pfCount > 0 ? new (string Name, string ParamVar)[pfCount] : []; - int idx = 0; - int pfIdxCounter = 0; - foreach(var m in injectionMembers) - { - if(m.MemberType is not (InjectionMemberType.Property or InjectionMemberType.Field)) - continue; - var varN = $"{instanceVarName}_p{pfIdxCounter}"; - var mt = m.Type; - var mtName = mt is null ? "object" : (memberTypeNameResolver is not null ? memberTypeNameResolver(mt) : mt.Name); - bool hasNonNullDefault = m.HasDefaultValue && !m.DefaultValueIsNull; - ResolveMemberValue(writer, mt, mtName, varN, m.Key, m.IsNullable, hasNonNullDefault, m.DefaultValue, asyncInitServiceTypeNames); - preProps[idx++] = (m.Name, varN); - pfIdxCounter++; - } - - // Resolve sync method parameters before instance creation - int memberParamIndex = pfIdxCounter; - var methodParamResolutions = ResolveMethodParamResolutions( - writer, injectionMembers, InjectionMemberType.Method, instanceVarName, ref memberParamIndex, - isKeyedRegistration, registrationKey, memberTypeNameResolver, asyncInitServiceTypeNames); - - // Resolve async method parameters before instance creation (only when in async mode) - var asyncMethodParamResolutions = isAsyncMode - ? ResolveMethodParamResolutions( - writer, injectionMembers, InjectionMemberType.AsyncMethod, instanceVarName, ref memberParamIndex, - isKeyedRegistration, registrationKey, memberTypeNameResolver, asyncInitServiceTypeNames) - : []; - - // Property/field values are resolved upfront; include all properties in the object initializer - var propertyInits = preProps.Select(p => $"{p.Name} = {p.ParamVar}").ToArray(); - var constructorArgs = BuildArgumentListFromEntries([.. constructorParamEntries]); - var initializerPart = propertyInits.Length > 0 ? $" {{ {string.Join(", ", propertyInits)} }}" : ""; - var constructorInvocation = BuildConstructorInvocation(implTypeName, constructorArgs, initializerPart); - writer.WriteLine($"var {instanceVarName} = {constructorInvocation};"); - - // Call sync methods - EmitMethodInvocations(writer, instanceVarName, methodParamResolutions, useAwait: false); - - // Await async methods (only in async mode) - EmitMethodInvocations(writer, instanceVarName, asyncMethodParamResolutions, useAwait: true); - } - - /// - /// Builds a constructor invocation expression with an optional initializer. - /// - private static string BuildConstructorInvocation(string implTypeName, string args, string initializerPart) => - $"new {implTypeName}({args}){initializerPart}"; - - /// - /// Resolves method parameters of a given and emits their variable declarations. - /// Shared by sync () and async () resolution loops. - /// - private static List<(string MethodName, string?[] ParamVars, string[] ParamNames)> ResolveMethodParamResolutions( - SourceWriter writer, - ImmutableEquatableArray injectionMembers, - InjectionMemberType targetType, - string instanceVarName, - ref int memberParamIndex, - bool isKeyedRegistration, - string? registrationKey, - Func? memberTypeNameResolver, - HashSet? asyncInitServiceTypeNames) - { - var resolutions = new List<(string MethodName, string?[] ParamVars, string[] ParamNames)>(); - foreach(var method in injectionMembers) - { - if(method.MemberType != targetType) - continue; - var mParams = method.Parameters ?? []; - var mVars = new string?[mParams.Length]; - var mNames = new string[mParams.Length]; - int mi = 0; - foreach(var p in mParams) - { - var pVar = $"{instanceVarName}_m{memberParamIndex}"; - mNames[mi] = p.Name; - if(method.Key is not null) - { - mVars[mi] = ResolveMethodParameterWithKey( - writer, - p, - pVar, - method.Key, - isKeyedRegistration, - registrationKey, - memberTypeNameResolver, - asyncInitServiceTypeNames); - } - else - { - mVars[mi] = ResolveParamAndEmitVar(writer, p, pVar, isKeyedRegistration, registrationKey, asyncInitServiceTypeNames); - } - mi++; - memberParamIndex++; - } - resolutions.Add((method.Name, mVars, mNames)); - } - return resolutions; - } - - /// - /// Emits method invocation statements for the given . - /// When is , each call is prefixed with await. - /// - private static void EmitMethodInvocations( - SourceWriter writer, - string instanceVarName, - List<(string MethodName, string?[] ParamVars, string[] ParamNames)> resolutions, - bool useAwait) - { - foreach(var (mName, mVars, mNames) in resolutions) - { - var entries = new List<(string Name, string? Value, bool NeedsConditional)>(mVars.Length); - for(int i = 0; i < mVars.Length; i++) - entries.Add((mNames[i], mVars[i], false)); - var args = BuildArgumentListFromEntries([.. entries]); - writer.WriteLine(useAwait - ? $"await {instanceVarName}.{mName}({args});" - : $"{instanceVarName}.{mName}({args});"); - } - } - - /// - /// Resolves a property or field injection value and emits its variable declaration. - /// - private static void ResolveMemberValue( - SourceWriter writer, - TypeData? memberType, - string memberTypeName, - string paramVar, - string? serviceKey, - bool isNullable, - bool hasNonNullDefault, - string? defaultValue, - HashSet? asyncInitServiceTypeNames = null) - { - if(memberType is CollectionWrapperTypeData) - { - WriteCollectionResolution(writer, memberType, paramVar, serviceKey, isOptional: isNullable); - return; - } - - if(memberType is LazyTypeData or FuncTypeData or DictionaryTypeData or KeyValuePairTypeData or TaskTypeData) - { - WriteWrapperResolution(writer, memberType, paramVar, serviceKey, isOptional: isNullable, asyncInitServiceTypeNames); - return; - } - - if(hasNonNullDefault) - { - var defExpr = defaultValue ?? "default"; - var svcCall = BuildServiceCall(serviceKey is not null ? GetKeyedService : GetService, memberTypeName, serviceKey); - writer.WriteLine($"var {paramVar} = {svcCall} ?? {defExpr};"); - return; - } - - var methodName = isNullable - ? (serviceKey is not null ? GetKeyedService : GetService) - : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); - var call = BuildServiceCall(methodName, memberTypeName, serviceKey); - writer.WriteLine($"var {paramVar} = {call};"); - } - - /// - /// Resolves a keyed method parameter and emits its variable declaration. - /// - private static string? ResolveMethodParameterWithKey( - SourceWriter writer, - ParameterData param, - string paramVar, - string methodKey, - bool isKeyedRegistration, - string? registrationKey, - Func? typeNameResolver, - HashSet? asyncInitServiceTypeNames = null) - { - if(TryResolveCommonParameter(writer, param, paramVar, isKeyedRegistration, registrationKey, methodKey, isOptional: false, typeNameResolver, asyncInitServiceTypeNames, out var resolvedVar)) - { - return resolvedVar!; - } - - var resolvedTypeName = typeNameResolver is not null ? typeNameResolver(param.Type) : param.Type.Name; - if(param.HasDefaultValue) - { - var defExpr = param.DefaultValue ?? "default"; - var svcCall = BuildServiceCall(GetKeyedService, resolvedTypeName, methodKey); - writer.WriteLine($"var {paramVar} = {svcCall} ?? {defExpr};"); - return paramVar; - } - - var requiredCall = BuildServiceCall(GetRequiredKeyedService, resolvedTypeName, methodKey); - writer.WriteLine($"var {paramVar} = {requiredCall};"); - return paramVar; - } - - - - /// - /// Builds an argument list from entries with conditional flag. - /// - private static string BuildArgumentListFromEntries((string Name, string? Value, bool NeedsConditional)[] entries) - { - // Check if any parameter uses default value - bool hasDefaultValue = entries.Any(e => e.Value is null); - - if(!hasDefaultValue) - { - // All values present - use positional arguments - return string.Join(", ", entries.Select(e => e.Value!)); - } - - // Some values are null - use named arguments for non-null values only - var namedArgs = entries.Where(e => e.Value is not null).Select(e => $"{e.Name}: {e.Value}"); - return string.Join(", ", namedArgs); - } - - /// - /// Converts a TypeData to typeof() syntax for open generic types. - /// For example: TypeData with Name="global::Namespace.GenericTest<T>" becomes "typeof(global::Namespace.GenericTest<>)" - /// - private static string ConvertToTypeOf(TypeData typeData) - { - if(typeData is not GenericTypeData genericTypeData || genericTypeData.GenericArity == 0) - { - return $"typeof({typeData.Name})"; - } - - // Build the open generic typeof - return $"typeof({genericTypeData.NameWithoutGeneric}{GetGenericString(genericTypeData.GenericArity)})"; - } - - /// - /// Cached generic arity strings to avoid repeated allocations. - /// - private static readonly string[] s_genericArityStrings = - [ - "<>", - "<,>", - "<,,>", - "<,,,>", - "<,,,,>", - "<,,,,,>", - "<,,,,,,>", - "<,,,,,,,>", - "<,,,,,,,,>" - ]; - - private static string GetGenericString(in int arity) => - arity <= 9 ? s_genericArityStrings[arity - 1] : '<' + new string(',', arity - 1) + '>'; - - /// - /// Writes decorator pattern registration code. - /// - /// - /// Generates code like: - /// - /// services.AddSingleton<IMyService>((IServiceProvider sp) => - /// { - /// var s0 = sp.GetRequiredService<MyService>(); - /// var s1_p0 = sp.GetRequiredService<ILogger<MyServiceDecorator2>>(); - /// var s1 = new MyServiceDecorator2(s1_p0, s0); - /// var s2_p0 = sp.GetRequiredService<ILogger<MyServiceDecorator>>(); - /// var s2 = new MyServiceDecorator(s2_p0, s1); - /// return s2; - /// }); - /// - /// For open generic decorators, falls back to ActivatorUtilities.CreateInstance. - /// - private static void WriteDecoratorRegistration(SourceWriter writer, ServiceRegistrationModel registration, string lifetime, HashSet? asyncInitServiceTypeNames = null) - { - var decorators = registration.Decorators; - // Decorators array is in order from outermost to innermost, - // we iterate in reverse order for building the chain from inner to outer - var decoratorCount = decorators.Length; - - // Get generic arguments from the service type (for closed generic service types) - var serviceTypeParams = registration.ServiceType is GenericTypeData genericServiceType - ? genericServiceType.TypeParameters - : null; - var serviceTypeName = registration.ServiceType.Name; - - // Build service type names set for IsServiceParameter check - var serviceTypeNames = BuildServiceTypeNames(registration); - - WriteServiceRegistrationLambdaStart(writer, lifetime, serviceTypeName, registration.Key); - - writer.WriteLine("{"); - writer.Indentation++; - - // s0 = implementation instance - if(registration.Key is not null) - { - writer.WriteLine($"var s0 = sp.GetRequiredKeyedService<{registration.ImplementationType.Name}>(key);"); - } - else - { - writer.WriteLine($"var s0 = sp.GetRequiredService<{registration.ImplementationType.Name}>();"); - } - - // Build decorator chain (iterate in reverse order: from innermost to outermost) - for(int i = 0; i < decoratorCount; i++) - { - var decorator = decorators[decoratorCount - 1 - i]; - var prevVar = $"s{i}"; - var currentVar = $"s{i + 1}"; - - var decoratorTypeName = GetClosedDecoratorTypeName(decorator, serviceTypeParams); - var ctorParams = decorator.ConstructorParameters; - if(ctorParams is null || ctorParams.Length == 0) - { - writer.WriteLine($"var {currentVar} = global::Microsoft.Extensions.DependencyInjection.ActivatorUtilities.CreateInstance<{decoratorTypeName}>(sp, {prevVar});"); - continue; - } - - // Use shared helper to construct decorator with Option A conditional handling and service parameter wiring - bool isKeyedRegistration = registration.Key is not null; - WriteConstructInstanceWithInjection( - writer, - instanceVarName: currentVar, - implTypeName: decoratorTypeName, - constructorParams: ctorParams, - injectionMembers: decorator.InjectionMembers ?? [], - isKeyedRegistration: isKeyedRegistration, - registrationKey: registration.Key, - serviceTypeNames: serviceTypeNames, - ctorTypeNameResolver: t => decorator is GenericTypeData { IsOpenGeneric: true } && serviceTypeParams is not null ? SubstituteGenericArguments(t, decorator, serviceTypeParams) : t.Name, - memberTypeNameResolver: t => decorator is GenericTypeData { IsOpenGeneric: true } && serviceTypeParams is not null ? SubstituteGenericArguments(t, decorator, serviceTypeParams) : t.Name, - decoratedPrevVar: prevVar, - asyncInitServiceTypeNames: asyncInitServiceTypeNames); - } - - // Return the outermost decorator - writer.WriteLine($"return s{decoratorCount};"); - - writer.Indentation--; - writer.WriteLine("});"); - } - - /// - /// Gets the closed decorator type name by substituting generic arguments if the decorator is an open generic. - /// - /// The decorator type data. - /// The type parameters from the service type. - /// The closed decorator type name. - private static string GetClosedDecoratorTypeName(TypeData decorator, ImmutableEquatableArray? serviceTypeParams) - { - // If decorator is not open generic, return as-is - if(decorator is not GenericTypeData { IsOpenGeneric: true } genericDecorator) - { - return decorator.Name; - } - - // If no generic arguments available from service type, return as-is (this shouldn't happen in valid scenarios) - if(serviceTypeParams is null || serviceTypeParams.Length == 0) - { - return decorator.Name; - } - - // Replace the open generic parameters with the service type's generic arguments - // e.g., "global::Namespace.Decorator" -> "global::Namespace.Decorator" - return $"{genericDecorator.NameWithoutGeneric}<{string.Join(", ", serviceTypeParams.Select(a => a.Type.Name))}>"; - } - - /// - /// Substitutes generic type parameters in a parameter type with actual generic arguments. - /// - /// The parameter type data that may contain type parameters. - /// The decorator type data containing type parameter information. - /// The type parameters from the service type. - /// The type name with substituted generic arguments. - private static string SubstituteGenericArguments(TypeData paramType, TypeData decorator, ImmutableEquatableArray serviceTypeParams) - { - // Use pre-extracted type arguments from TypeData - var decoratorTypeParams = decorator is GenericTypeData genericDecorator - ? genericDecorator.TypeParameters - : null; - if(decoratorTypeParams is null || decoratorTypeParams.Length == 0) - { - return paramType.Name; - } - - if(serviceTypeParams.Length != decoratorTypeParams.Length) - { - // Mismatch in arity, return original - return paramType.Name; - } - - // Build replacement map and apply substitutions - var result = paramType.Name; - for(int i = 0; i < decoratorTypeParams.Length; i++) - { - // Replace type parameter with actual argument from service type - // Need to be careful to replace whole type names, not partial matches - result = ReplaceTypeParameter(result, decoratorTypeParams[i].ParameterName, serviceTypeParams[i].Type.Name); - } - - return result; - } - - /// - /// Builds a set of service type names for IsServiceParameter check. - /// Includes the service type, implementation type, all base classes, and all interfaces. - /// This ensures decorator parameters that reference any service type are correctly identified. - /// - private static HashSet BuildServiceTypeNames(ServiceRegistrationModel registration) - { - var serviceTypeNames = new HashSet(StringComparer.Ordinal); - - // Add service type variants - AddTypeNameVariants(serviceTypeNames, registration.ServiceType); - - // Add implementation type variants - AddTypeNameVariants(serviceTypeNames, registration.ImplementationType); - - // Add all base classes (important for decorator parameter matching) - if(registration.ImplementationType.AllBaseClasses is not null) - { - foreach(var baseClass in registration.ImplementationType.AllBaseClasses) - { - AddTypeNameVariants(serviceTypeNames, baseClass); - } - } - - // Add all interfaces (important when multiple service types exist) - if(registration.ImplementationType.AllInterfaces is not null) - { - foreach(var iface in registration.ImplementationType.AllInterfaces) - { - AddTypeNameVariants(serviceTypeNames, iface); - } - } - - return serviceTypeNames; - } - - /// - /// Adds both the full name and non-generic name variants to the set. - /// - private static void AddTypeNameVariants(HashSet set, TypeData type) - { - set.Add(type.Name); - if(type is GenericTypeData genericType && type.Name != genericType.NameWithoutGeneric) - { - set.Add(genericType.NameWithoutGeneric); - } - } - - /// - /// Checks if a parameter type matches any of the service types. - /// - private static bool IsServiceTypeParameter(TypeData paramType, string substitutedTypeName, HashSet serviceTypeNames) - { - // Check substituted type name first (for open generic decorators) - if(serviceTypeNames.Contains(substitutedTypeName)) - { - return true; - } - - // Direct match on full name or non-generic name - if(serviceTypeNames.Contains(paramType.Name)) - { - return true; - } - - return paramType is GenericTypeData genericParamType - && serviceTypeNames.Contains(genericParamType.NameWithoutGeneric); - } - - /// - /// Writes parameter resolution code for [ServiceKey] attribute. - /// When the service is registered as keyed, injects the registration key with appropriate type casting. - /// When the service is not keyed, injects null. - /// - /// The source writer. - /// The variable name for this parameter. - /// The resolved parameter type name. - /// Whether this is a keyed service registration. - /// The registration key for keyed services. - /// The variable name or expression to use for this parameter. - private static string WriteServiceKeyParameterResolution( - SourceWriter writer, - string paramVar, - string paramTypeName, - bool isKeyedRegistration, - string? registrationKey) - { - if(isKeyedRegistration && registrationKey is not null) - { - writer.WriteLine($"var {paramVar} = {registrationKey};"); - } - else - { - // Non-keyed registration: use the default value for the key type - writer.WriteLine($"var {paramVar} = default({paramTypeName});"); - } - return paramVar; - } - - /// - /// Resolve a parameter and emit its variable declaration. - /// Produces lines like: - /// var p = sp.GetService() ?? ; - /// var p = sp.GetKeyedService(key) ?? ; - /// or returns "sp" for IServiceProvider parameters (no var emitted). - /// - private static string ResolveParamAndEmitVar( - SourceWriter writer, - ParameterData param, - string paramVar, - bool isKeyedRegistration, - string? registrationKey = null, - HashSet? asyncInitServiceTypeNames = null) - { - var paramTypeName = param.Type.Name; - - if(TryResolveCommonParameter(writer, param, paramVar, isKeyedRegistration, registrationKey, param.ServiceKey, param.IsOptional, typeNameResolver: null, asyncInitServiceTypeNames, out var resolvedVar)) - { - return resolvedVar!; - } - - // General case: if the parameter is optional (HasDefaultValue or IsOptional), use GetService + ?? default; otherwise use GetRequiredService - if(param.HasDefaultValue || param.IsOptional) - { - var methodName = param.ServiceKey is not null ? GetKeyedService : GetService; - var svcCall = BuildServiceCall(methodName, paramTypeName, param.ServiceKey); - var defExpr = param.HasDefaultValue ? (param.DefaultValue ?? "default") : "default"; - writer.WriteLine($"var {paramVar} = {svcCall} ?? {defExpr};"); - return paramVar; - } - else - { - var methodName = param.ServiceKey is not null ? GetRequiredKeyedService : GetRequiredService; - var svcCall = BuildServiceCall(methodName, paramTypeName, param.ServiceKey); - writer.WriteLine($"var {paramVar} = {svcCall};"); - return paramVar; - } - } - - /// - /// Writes collection resolution code for constructor parameters and injection members. - /// Handles all collection types including IEnumerable<T>, IList<T>, T[], IReadOnlyList<T>, etc. - /// - /// The source writer. - /// The type data of the collection. - /// The variable name for this parameter. - /// The service key (null for non-keyed services). - /// Whether the parameter is optional (uses GetService instead of GetRequiredService). - private static void WriteCollectionResolution( - SourceWriter writer, - TypeData type, - string paramVar, - string? serviceKey, - bool isOptional = false) - { - var isKeyed = serviceKey is not null; - - // Extract element type name based on collection type - var elementTypeName = type switch - { - EnumerableTypeData e => e.ElementType.Name, - ReadOnlyCollectionTypeData r => r.ElementType.Name, - CollectionTypeData c => c.ElementType.Name, - ReadOnlyListTypeData rl => rl.ElementType.Name, - ListTypeData l => l.ElementType.Name, - ArrayTypeData a => a.ElementType.Name, - _ => null - }; - - if(elementTypeName is null) - { - // Fallback: not a collection type, use direct service resolution - if(isKeyed) - { - var method = isOptional ? GetKeyedService : GetRequiredKeyedService; - var call = BuildServiceCall(method, type.Name, serviceKey); - writer.WriteLine($"var {paramVar} = {call};"); - } - else - { - var method = isOptional ? GetService : GetRequiredService; - var call = BuildServiceCall(method, type.Name, serviceKey: null); - writer.WriteLine($"var {paramVar} = {call};"); - } - return; - } - - // Use WrapperKind to determine the resolution method - switch(((WrapperTypeData)type).WrapperKind) - { - case WrapperKind.Enumerable: - // IEnumerable - use GetServices directly - if(isKeyed) - { - var call = BuildServiceCall(GetKeyedServices, elementTypeName, serviceKey); - writer.WriteLine($"var {paramVar} = {call};"); - } - else - { - var call = BuildServiceCall(GetServices, elementTypeName, serviceKey: null); - writer.WriteLine($"var {paramVar} = {call};"); - } - break; - - case WrapperKind.ReadOnlyCollection: - case WrapperKind.Collection: - case WrapperKind.ReadOnlyList: - case WrapperKind.List: - case WrapperKind.Array: - // IReadOnlyCollection, ICollection, IReadOnlyList, IList, T[] - resolve as array - if(isKeyed) - { - var call = BuildServiceCall(GetKeyedServices, elementTypeName, serviceKey); - writer.WriteLine($"var {paramVar} = {call}.ToArray();"); - } - else - { - var call = BuildServiceCall(GetServices, elementTypeName, serviceKey: null); - writer.WriteLine($"var {paramVar} = {call}.ToArray();"); - } - break; - - default: - // Fallback for unknown collection types - if(isKeyed) - { - var method = isOptional ? GetKeyedService : GetRequiredKeyedService; - var call = BuildServiceCall(method, type.Name, serviceKey); - writer.WriteLine($"var {paramVar} = {call};"); - } - else - { - var method = isOptional ? GetService : GetRequiredService; - var call = BuildServiceCall(method, type.Name, serviceKey: null); - writer.WriteLine($"var {paramVar} = {call};"); - } - break; - } - } - - /// - /// Writes wrapper type resolution code for Lazy<T>, Func<T>, IDictionary<TKey, TValue>, - /// and KeyValuePair<TKey, TValue>. Supports nested wrapper types (e.g., Lazy<KeyValuePair<K, V>>). - /// - /// The source writer. - /// The wrapper type data. - /// The variable name for this parameter. - /// The service key (null for non-keyed services). - /// Whether the parameter is optional. - private static void WriteWrapperResolution( - SourceWriter writer, - TypeData type, - string paramVar, - string? serviceKey, - bool isOptional = false, - HashSet? asyncInitServiceTypeNames = null) - { - var expr = BuildWrapperExpression(type, serviceKey, isOptional, asyncInitServiceTypeNames); - writer.WriteLine($"var {paramVar} = {expr};"); - } - - /// - /// Builds an inline wrapper expression. Recursively handles nested wrappers. - /// - /// The wrapper type data to build an expression for. - /// The service key (null for non-keyed services). - /// Whether this resolution is optional. - /// A C# expression string that resolves the wrapper type. - private static string BuildWrapperExpression(TypeData type, string? serviceKey, bool isOptional, HashSet? asyncInitServiceTypeNames = null) - { - switch(type) - { - case LazyTypeData lazy: - { - var innerType = lazy.InstanceType; - // Direct Lazy where T is not a wrapper β€” resolve from DI (standalone registration exists) - if(innerType is not WrapperTypeData) - { - var methodName = isOptional - ? (serviceKey is not null ? GetKeyedService : GetService) - : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); - return BuildServiceCall(methodName, type.Name, serviceKey); - } - // Nested wrapper (e.g., Lazy>) β€” inline construction - var lazyInnerExpr = BuildInnerResolutionExpression(innerType, serviceKey, isOptional, asyncInitServiceTypeNames); - return $"new global::System.Lazy<{innerType.Name}>(() => {lazyInnerExpr}, global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication)"; - } - - case FuncTypeData func: - { - var innerType = func.ReturnType; - if(func.HasInputParameters) - { - var methodName = isOptional - ? (serviceKey is not null ? GetKeyedService : GetService) - : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); - return BuildServiceCall(methodName, type.Name, serviceKey); - } - - // Direct Func where T is not a wrapper β€” resolve from DI (standalone registration exists) - if(innerType is not WrapperTypeData) - { - var methodName = isOptional - ? (serviceKey is not null ? GetKeyedService : GetService) - : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); - return BuildServiceCall(methodName, type.Name, serviceKey); - } - // Nested wrapper (e.g., Func>) β€” inline construction - var funcInnerExpr = BuildInnerResolutionExpression(innerType, serviceKey, isOptional, asyncInitServiceTypeNames); - return $"new global::System.Func<{innerType.Name}>(() => {funcInnerExpr})"; - } - - case KeyValuePairTypeData kvp: - { - var keyType = kvp.KeyType; - var valueType = kvp.ValueType; - // KeyValuePair uses the registration's service key as the key value - var keyExpr = serviceKey ?? "default"; - var valueExpr = BuildInnerResolutionExpression(valueType, serviceKey, isOptional, asyncInitServiceTypeNames); - return $"new global::System.Collections.Generic.KeyValuePair<{keyType.Name}, {valueType.Name}>({keyExpr}, {valueExpr})"; - } - - case DictionaryTypeData dict: - { - // Dictionary resolution: collect all KeyValuePair services and convert to dictionary. - var keyType = dict.KeyType; - var valueType = dict.ValueType; - var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{keyType.Name}, {valueType.Name}>"; - var getServicesCall = BuildServiceCall( - serviceKey is not null ? GetKeyedServices : GetServices, - kvpTypeName, - serviceKey); - return $"{getServicesCall}.ToDictionary()"; - } - - case TaskTypeData task: - { - // Task wrapper: if inner type is an async-init service, resolve Task directly; - // otherwise wrap synchronous resolution in Task.FromResult(...). - var innerTypeName = task.InnerType.Name; - if(asyncInitServiceTypeNames?.Contains(innerTypeName) == true) - { - // Async-init service: Task is registered directly. - var methodName = isOptional - ? (serviceKey is not null ? GetKeyedService : GetService) - : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); - return BuildServiceCall(methodName, type.Name, serviceKey); - } - else - { - // Sync-only service: wrap with Task.FromResult. - var syncMethodName = isOptional - ? (serviceKey is not null ? GetKeyedService : GetService) - : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); - var syncCall = BuildServiceCall(syncMethodName, innerTypeName, serviceKey); - return $"global::System.Threading.Tasks.Task.FromResult({syncCall})"; - } - } - - default: - { - // Fallback for non-wrapper types - var methodName = isOptional - ? (serviceKey is not null ? GetKeyedService : GetService) - : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); - return BuildServiceCall(methodName, type.Name, serviceKey); - } - } - } - - /// - /// Builds an inner resolution expression β€” either a nested wrapper expression, a collection - /// expression, or a direct service call. Supports nesting such as Lazy<IEnumerable<T>>. - /// - private static string BuildInnerResolutionExpression(TypeData innerType, string? serviceKey, bool isOptional, HashSet? asyncInitServiceTypeNames = null) - { - // If the inner type is itself a wrapper, recurse to handle nesting (e.g., Lazy>) - if(innerType is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData or TaskTypeData) - { - return BuildWrapperExpression(innerType, serviceKey, isOptional, asyncInitServiceTypeNames); - } - - // Collection types inside wrappers (e.g., Lazy>) - if(innerType is CollectionWrapperTypeData collectionInner) - { - var elementTypeName = (string?)collectionInner.ElementType.Name; - - if(elementTypeName is not null) - { - var getServicesCall = BuildServiceCall( - serviceKey is not null ? GetKeyedServices : GetServices, - elementTypeName, - serviceKey); - return innerType is CollectionWrapperTypeData and not EnumerableTypeData - ? $"{getServicesCall}.ToArray()" - : getServicesCall; - } - } - - // Otherwise, use direct service resolution - var methodName = isOptional - ? (serviceKey is not null ? GetKeyedService : GetService) - : (serviceKey is not null ? GetRequiredKeyedService : GetRequiredService); - return BuildServiceCall(methodName, innerType.Name, serviceKey); - } - - /// - /// Builds a service resolution call for keyed or non-keyed services. - /// - private static string BuildServiceCall(string methodName, string typeName, string? serviceKey) => - serviceKey is not null - ? $"sp.{methodName}<{typeName}>({serviceKey})" - : $"sp.{methodName}<{typeName}>()"; - - /// - /// Attempts to resolve common parameter cases and emit any required variable declarations. - /// Returns true when a resolution was produced. - /// - private static bool TryResolveCommonParameter( - SourceWriter writer, - ParameterData param, - string paramVar, - bool isKeyedRegistration, - string? registrationKey, - string? serviceKey, - bool isOptional, - Func? typeNameResolver, - HashSet? asyncInitServiceTypeNames, - out string? resolvedVar) - { - if(IsServiceProviderType(param.Type.Name)) - { - resolvedVar = "sp"; - return true; - } - - var resolvedTypeName = typeNameResolver is not null ? typeNameResolver(param.Type) : param.Type.Name; - if(param.HasServiceKeyAttribute) - { - resolvedVar = WriteServiceKeyParameterResolution(writer, paramVar, resolvedTypeName, isKeyedRegistration, registrationKey); - return true; - } - - if(param.Type is CollectionWrapperTypeData) - { - WriteCollectionResolution(writer, param.Type, paramVar, serviceKey, isOptional); - resolvedVar = paramVar; - return true; - } - - if(param.Type is LazyTypeData or FuncTypeData or DictionaryTypeData or KeyValuePairTypeData or TaskTypeData) - { - WriteWrapperResolution(writer, param.Type, paramVar, serviceKey, isOptional, asyncInitServiceTypeNames); - resolvedVar = paramVar; - return true; - } - - resolvedVar = null; - return false; - } - - /// - /// Checks if the type name represents System.IServiceProvider. - /// - private static bool IsServiceProviderType(string typeName) => - typeName is IServiceProviderGlobalTypeName or IServiceProviderTypeName or "IServiceProvider"; - - /// - /// Builds the generic type arguments string for a generic factory method. - /// Uses the to map placeholder types in the service type template - /// to the actual types from the closed service type. - /// - /// The factory method data containing the generic type mapping. - /// The closed service type to extract type arguments from. - /// The generic type arguments string (e.g., "Entity, Dto"), or null if not a generic factory. - /// - /// Given: - /// - ServiceTypeTemplate: IRequestHandler<Task<int>> - /// - PlaceholderToTypeParamMap: { "int" -> 0 } - /// - ClosedServiceType: IRequestHandler<Task<Entity>> - /// Returns: "Entity" - /// - private static string? BuildGenericFactoryTypeArgs(FactoryMethodData factory, TypeData closedServiceType) - { - var mapping = factory.GenericTypeMapping; - if(mapping is null || factory.TypeParameterCount == 0) - { - return null; - } - - var template = mapping.ServiceTypeTemplate; - var placeholderMap = mapping.PlaceholderToTypeParamMap; - - // Build a map from placeholder types to actual types by comparing template with closed service type - var placeholderToActualType = new Dictionary(StringComparer.Ordinal); - ExtractPlaceholderMappings(template, closedServiceType, placeholderToActualType); - - // Build type arguments array in the order of factory method's type parameters - var typeArgs = new string[factory.TypeParameterCount]; - foreach(var kvp in placeholderMap) - { - var placeholderTypeName = kvp.Key; - var typeParamIndex = kvp.Value; - - if(typeParamIndex < 0 || typeParamIndex >= typeArgs.Length) - { - continue; - } - - if(placeholderToActualType.TryGetValue(placeholderTypeName, out var actualTypeName)) - { - typeArgs[typeParamIndex] = actualTypeName; - } - } - - // Validate all type arguments are filled - foreach(var arg in typeArgs) - { - if(string.IsNullOrEmpty(arg)) - { - return null; // Missing type argument, cannot generate generic call - } - } - - return string.Join(", ", typeArgs); - } - - /// - /// Recursively extracts mappings from placeholder types in the template to actual types in the closed type. - /// - private static void ExtractPlaceholderMappings( - TypeData template, - TypeData closed, - Dictionary placeholderToActualType) - { - if(template is not GenericTypeData genericTemplate || closed is not GenericTypeData genericClosed) - { - return; - } - - // If base type names don't match, can't extract mappings - if(genericTemplate.NameWithoutGeneric != genericClosed.NameWithoutGeneric) - { - return; - } - - var templateParams = genericTemplate.TypeParameters; - var closedParams = genericClosed.TypeParameters; - - if(templateParams is null || closedParams is null) - { - return; - } - - if(templateParams.Length != closedParams.Length) - { - return; - } - - for(int i = 0; i < templateParams.Length; i++) - { - var templateParamType = templateParams[i].Type; - var closedParamType = closedParams[i].Type; - - // If the template param is a simple type (no nested type parameters), - // it's a placeholder that maps to the closed param type - if(templateParamType is not GenericTypeData { TypeParameters.Length: > 0 }) - { - // Map the template type name to the closed type name - // e.g., "global::System.Int32" -> "global::TestNamespace.Entity" - placeholderToActualType[templateParamType.Name] = closedParamType.Name; - } - else - { - // Nested generic type, recurse - ExtractPlaceholderMappings(templateParamType, closedParamType, placeholderToActualType); - } - } - } -} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GroupRegistrationsForContainer.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GroupRegistrationsForContainer.cs deleted file mode 100644 index b2041b5..0000000 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/GroupRegistrationsForContainer.cs +++ /dev/null @@ -1,388 +0,0 @@ -ο»Ώnamespace SourceGen.Ioc; - -partial class IocSourceGenerator -{ - /// - /// Transforms container and registrations into grouped data for code generation. - /// This step is separated from output generation to enable incremental generator caching. - /// - private static ContainerWithGroups GroupRegistrationsForContainer( - ContainerModel container, - ImmutableEquatableArray allRegistrations) - { - // Filter registrations based on ExplicitOnly mode - var registrations = FilterRegistrationsForContainer(container, allRegistrations); - - // Collect partial accessor method names to avoid naming conflicts - var reservedNames = new HashSet(StringComparer.Ordinal); - foreach(var accessor in container.PartialAccessors) - { - if(accessor.Kind == PartialAccessorKind.Method) - { - reservedNames.Add(accessor.Name); - } - } - - // Group registrations for code generation - var groups = BuildContainerRegistrationGroups(registrations, container.EagerResolveOptions, reservedNames); - - return new ContainerWithGroups(container, groups); - } - - /// - /// Filters registrations based on the container's ExplicitOnly and IncludeTags settings. - /// Priority: ExplicitOnly > IncludeTags > All registrations. - /// - private static ImmutableEquatableArray FilterRegistrationsForContainer( - ContainerModel container, - ImmutableEquatableArray allRegistrations) - { - // ExplicitOnly takes precedence over IncludeTags - if(container.ExplicitOnly) - { - // Only include explicit registrations from the container class - var builder = new List(); - - foreach(var explicitReg in container.ExplicitRegistrations) - { - // Convert RegistrationData to ServiceRegistrationModel - // For explicit registrations, we process them with default settings - var processed = ProcessExplicitRegistrationForContainer(explicitReg); - if(processed is not null) - { - builder.Add(processed); - } - } - - return builder.ToImmutableEquatableArray(); - } - - // Apply IncludeTags filtering if specified - if(container.IncludeTags.Length > 0) - { - // Only include services that have at least one matching tag - return allRegistrations - .Where(r => r.Tags.Length > 0 && r.Tags.Any(tag => container.IncludeTags.Contains(tag, StringComparer.Ordinal))) - .Select(static r => r.Registration) - .ToImmutableEquatableArray(); - } - - // Include all registrations from the assembly - return allRegistrations - .Select(static r => r.Registration) - .ToImmutableEquatableArray(); - } - - /// - /// Converts explicit RegistrationData to ServiceRegistrationModel for container generation. - /// - private static ServiceRegistrationModel? ProcessExplicitRegistrationForContainer(RegistrationData data) - { - // Get the first service type, or use implementation type - var serviceType = data.ServiceTypes.Length > 0 - ? data.ServiceTypes[0] - : data.ImplementationType; - - return new ServiceRegistrationModel( - serviceType, - data.ImplementationType, - data.Lifetime, - data.Key, - data.KeyType, - data.KeyValueType, - data.ImplementationType is GenericTypeData { IsOpenGeneric: true }, - data.Decorators, - data.InjectionMembers, - data.Factory, - data.Instance); - } - - /// - /// Groups registrations by service type and key for efficient lookup and collection resolution. - /// Pre-computes field names, method names, and disposal lists to avoid redundant calculations. - /// - private static ContainerRegistrationGroups BuildContainerRegistrationGroups( - ImmutableEquatableArray registrations, - EagerResolveOptions eagerResolveOptions, - HashSet reservedNames) - { - // Group by (ServiceType.Name, Key) for efficient lookup - var byServiceTypeAndKey = new Dictionary<(string ServiceType, string? Key), List>(); - - // Track all unique service types for IsService checks - var allServiceTypes = new HashSet(); - - // Track unique implementations per lifetime using Dictionary instead of List + index tracking. - // Key: (ImplementationName, ServiceKey, InstanceOrFactory), Value: (CachedRegistration, HasDecorators) - var singletonMap = new Dictionary<(string ImplName, string? Key, string? InstanceOrFactory), (CachedRegistration Cached, bool HasDecorators)>(); - var scopedMap = new Dictionary<(string ImplName, string? Key, string? InstanceOrFactory), (CachedRegistration Cached, bool HasDecorators)>(); - var transientMap = new Dictionary<(string ImplName, string? Key, string? InstanceOrFactory), (CachedRegistration Cached, bool HasDecorators)>(); - - var hasOpenGenerics = false; - var hasKeyedServices = false; - - foreach(var reg in registrations) - { - if(reg.IsOpenGeneric) - { - hasOpenGenerics = true; - // Skip open generics for most processing but track the flag - continue; - } - - // Pre-compute field and method names once, including IsEager flag - var cached = CreateCachedRegistration(reg, eagerResolveOptions, reservedNames); - - var key = (reg.ServiceType.Name, reg.Key); - if(reg.Key is not null) - { - hasKeyedServices = true; - } - - if(!byServiceTypeAndKey.TryGetValue(key, out var list)) - { - list = []; - byServiceTypeAndKey[key] = list; - } - list.Add(cached); - - // Also add implementation type as a service type (for self-registration) - if(reg.ImplementationType.Name != reg.ServiceType.Name) - { - var implKey = (reg.ImplementationType.Name, reg.Key); - if(!byServiceTypeAndKey.TryGetValue(implKey, out var implList)) - { - implList = []; - byServiceTypeAndKey[implKey] = implList; - } - implList.Add(cached); - } - - // Track closed types for IsService checks - allServiceTypes.Add(reg.ServiceType.Name); - allServiceTypes.Add(reg.ImplementationType.Name); - - // Group by lifetime - prefer registration with decorators for field generation - // Include Instance or Factory in the key to distinguish multiple instance/factory registrations - var instanceOrFactory = reg.Instance ?? reg.Factory?.Path; - var lifetimeKey = (reg.ImplementationType.Name, reg.Key, instanceOrFactory); - var hasDecorators = reg.Decorators.Length > 0; - - var targetMap = reg.Lifetime switch - { - ServiceLifetime.Singleton => singletonMap, - ServiceLifetime.Scoped => scopedMap, - _ => transientMap - }; - - if(targetMap.TryGetValue(lifetimeKey, out var existing)) - { - // Skip if already seen with decorators, or if current doesn't have decorators - if(existing.HasDecorators || !hasDecorators) - { - continue; - } - } - - targetMap[lifetimeKey] = (cached, hasDecorators); - } - - // Convert maps to lists (preserving insertion order via Dictionary enumeration order in .NET) - var singletons = singletonMap.Values.Select(static v => v.Cached).ToList(); - var scoped = scopedMap.Values.Select(static v => v.Cached).ToList(); - var transients = transientMap.Values.Select(static v => v.Cached).ToList(); - - // Collect service types with multiple registrations for IEnumerable resolution - // Async-init services are excluded from collection resolution (only Task can access them). - var collectionServiceTypes = new List(); - var collectionRegistrations = new Dictionary>(); - - foreach(var kvp in byServiceTypeAndKey) - { - // Include non-keyed service types with multiple registrations - if(kvp.Key.Key is null && kvp.Value.Count > 1) - { - // Filter out async-init registrations β€” they cannot appear in IEnumerable resolvers - var effectiveRegistrations = kvp.Value.Where(static c => !c.IsAsyncInit).ToImmutableEquatableArray(); - - // Deduplicate resolver method names to count unique implementations - var uniqueResolvers = new HashSet(); - foreach(var cached in effectiveRegistrations) - { - uniqueResolvers.Add(cached.ResolverMethodName); - } - - // Only generate collection if there are multiple unique resolvers - if(uniqueResolvers.Count > 1) - { - collectionServiceTypes.Add(kvp.Key.ServiceType); - collectionRegistrations[kvp.Key.ServiceType] = effectiveRegistrations; - } - } - } - - // Pre-compute reversed lists for disposal (to avoid repeated .Reverse() calls) - var reversedSingletons = new List(singletons.Count); - for(var i = singletons.Count - 1; i >= 0; i--) - { - reversedSingletons.Add(singletons[i]); - } - - var reversedScoped = new List(scoped.Count); - for(var i = scoped.Count - 1; i >= 0; i--) - { - reversedScoped.Add(scoped[i]); - } - - // Convert to immutable collections - var immutableByServiceTypeAndKey = byServiceTypeAndKey - .ToImmutableEquatableDictionary( - kvp => kvp.Key, - kvp => kvp.Value.ToImmutableEquatableArray()); - - var immutableSingletons = singletons.ToImmutableEquatableArray(); - var immutableScoped = scoped.ToImmutableEquatableArray(); - var immutableTransients = transients.ToImmutableEquatableArray(); - - var eagerSingletons = immutableSingletons - .Where(static c => c.IsEager) - .ToImmutableEquatableArray(); - var eagerScoped = immutableScoped - .Where(static c => c.IsEager) - .ToImmutableEquatableArray(); - - var lazyEntries = CollectContainerLazyEntries( - immutableSingletons, - immutableScoped, - immutableTransients, - immutableByServiceTypeAndKey); - var funcEntries = CollectContainerFuncEntries( - immutableSingletons, - immutableScoped, - immutableTransients, - immutableByServiceTypeAndKey); - var kvpEntries = CollectContainerKvpEntries( - immutableSingletons, - immutableScoped, - immutableTransients, - immutableByServiceTypeAndKey); - - return new ContainerRegistrationGroups( - immutableByServiceTypeAndKey, - allServiceTypes.ToImmutableEquatableSet(), - immutableSingletons, - immutableScoped, - immutableTransients, - eagerSingletons, - eagerScoped, - lazyEntries, - funcEntries, - kvpEntries, - hasOpenGenerics, - hasKeyedServices, - collectionServiceTypes.ToImmutableEquatableArray(), - collectionRegistrations.ToImmutableEquatableDictionary(), - reversedSingletons.ToImmutableEquatableArray(), - reversedScoped.ToImmutableEquatableArray()); - } - - /// - /// Creates a CachedRegistration with pre-computed field and method names. - /// Computes both names in a single pass to avoid redundant string operations. - /// - private static CachedRegistration CreateCachedRegistration( - ServiceRegistrationModel reg, - EagerResolveOptions eagerResolveOptions, - HashSet reservedNames) - { - var (fieldName, methodName) = ComputeServiceNames(reg); - - // Avoid naming conflicts with user-declared partial accessor methods - if(reservedNames.Contains(methodName)) - { - methodName = $"{methodName}_Resolve"; - } - - // Determine if this registration should be eagerly resolved - // Instance registrations are inherently eager (no field caching needed) - // Transient services are not supported for eager resolution - // Async-init services must always be lazy (cannot be started in constructor) - var isAsyncInit = HasAsyncInitMembers(reg); - var isEager = reg.Instance is null && !isAsyncInit && reg.Lifetime switch - { - ServiceLifetime.Singleton => (eagerResolveOptions & EagerResolveOptions.Singleton) != 0, - ServiceLifetime.Scoped => (eagerResolveOptions & EagerResolveOptions.Scoped) != 0, - _ => false // Transient is never eager - }; - - return new CachedRegistration(reg, fieldName, methodName, isEager, isAsyncInit); - } - - /// - /// Returns when the registration has at least one - /// member, making it an async-init service. - /// - private static bool HasAsyncInitMembers(ServiceRegistrationModel reg) - { - foreach(var m in reg.InjectionMembers) - { - if(m.MemberType == InjectionMemberType.AsyncMethod) - return true; - } - return false; - } - - /// - /// Computes both field name and resolver method name for a service in a single pass. - /// This avoids redundant GetSafeIdentifier calls and string operations. - /// - /// A tuple containing (FieldName, ResolverMethodName). - private static (string FieldName, string ResolverMethodName) ComputeServiceNames(ServiceRegistrationModel reg) - { - var implType = reg.ImplementationType; - var typeName = implType switch - { - GenericTypeData { IsOpenGeneric: false } genericTypeData when genericTypeData.Name != genericTypeData.NameWithoutGeneric => genericTypeData.Name, - GenericTypeData genericTypeData => genericTypeData.NameWithoutGeneric, - _ => implType.Name, - }; - var baseName = GetSafeIdentifier(typeName); - var lowerFirstChar = char.ToLowerInvariant(baseName[0]); - var restOfName = baseName[1..]; - - // Handle keyed services - if(reg.Key is not null) - { - var safeKey = GetSafeIdentifier(reg.Key); - return ( - $"_{lowerFirstChar}{restOfName}_{safeKey}", - $"Get{baseName}_{safeKey}" - ); - } - - // Handle instance registrations - include instance name in the method name - if(reg.Instance is not null) - { - var safeInstance = GetSafeIdentifier(reg.Instance); - return ( - $"_{lowerFirstChar}{restOfName}_{safeInstance}", - $"Get{baseName}_{safeInstance}" - ); - } - - // Handle factory registrations - include factory path in the method name - if(reg.Factory is not null) - { - var safeFactory = GetSafeIdentifier(reg.Factory.Path); - return ( - $"_{lowerFirstChar}{restOfName}_{safeFactory}", - $"Get{baseName}_{safeFactory}" - ); - } - - return ( - $"_{lowerFirstChar}{restOfName}", - $"Get{baseName}" - ); - } -} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/IocSourceGenerator.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/IocSourceGenerator.cs deleted file mode 100644 index abf1e00..0000000 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/IocSourceGenerator.cs +++ /dev/null @@ -1,338 +0,0 @@ -ο»Ώnamespace SourceGen.Ioc; - -/// -/// Generates code to register types marked with SourceGen.Ioc.IocRegisterAttribute/SourceGen.Ioc.IocRegisterForAttribute -/// in Microsoft.Extensions.DependencyInjection container. -/// -[Generator(LanguageNames.CSharp)] -public sealed partial class IocSourceGenerator : IIncrementalGenerator -{ - public void Initialize(IncrementalGeneratorInitializationContext context) - { - // ========== IocRegisterAttribute providers ========== - // IocRegisterAttribute (non-generic) - var registerProvider = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocRegisterAttributeFullName, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformRegister(ctx, ct)) - .Where(static m => m is not null) - .Select(static (m, _) => m!); - - // IocRegisterAttribute - var registerProvider_T1 = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocRegisterAttributeFullName_T1, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformRegisterGeneric(ctx, ct)) - .Where(static m => m is not null) - .Select(static (m, _) => m!); - - // ========== IocRegisterForAttribute providers ========== - // IocRegisterForAttribute (non-generic) - var registerForProvider = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocRegisterForAttributeFullName, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformRegisterFor(ctx, ct)) - .SelectMany(static (m, _) => m); - - // IocRegisterForAttribute - var registerForProvider_T1 = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocRegisterForAttributeFullName_T1, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformRegisterForGeneric(ctx, ct)) - .SelectMany(static (m, _) => m); - - // ========== IocRegisterDefaultsAttribute providers ========== - // Transform IocRegisterDefaultsAttribute to get both DefaultSettings and ImplementationType registrations - // IocRegisterDefaultsAttribute (non-generic) - var defaultSettingsResultProvider = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocRegisterDefaultsAttributeFullName, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformDefaultSettings(ctx, ct)) - .SelectMany(static (m, _) => m); - - // IocRegisterDefaultsAttribute - var defaultSettingsResultProvider_T1 = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocRegisterDefaultsAttributeFullName_T1, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformDefaultSettingsGeneric(ctx, ct)) - .SelectMany(static (m, _) => m); - - // Combine all default settings result providers - var allDefaultSettingsResults = defaultSettingsResultProvider - .Collect() - .Combine(defaultSettingsResultProvider_T1.Collect()) - .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); - - // Pipeline 1: Extract DefaultSettingsModel from results (for default settings map) - var allDefaultSettings = allDefaultSettingsResults - .SelectMany(static (results, _) => results - .Where(static r => r.DefaultSettings is not null) - .Select(static r => r.DefaultSettings!)) - .Collect(); - - // Pipeline 2: Extract RegistrationData from results (for implementation type registrations) - var defaultSettingsImplTypeRegistrations = allDefaultSettingsResults - .SelectMany(static (results, _) => results - .SelectMany(static r => r.ImplementationTypeRegistrations)); - - // Pipeline 3: Extract OpenGenericEntries from results (for factory-based open generic registrations) - var factoryBasedOpenGenericEntries = allDefaultSettingsResults - .SelectMany(static (results, _) => results - .SelectMany(static r => r.OpenGenericEntries)) - .Collect(); - - // ========== IocImportModuleAttribute providers ========== - // Transform IocImportModuleAttribute to get both DefaultSettings and OpenGenericEntries in a single pass - // IocImportModuleAttribute (non-generic) - var importModuleResultProvider = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocImportModuleAttributeFullName, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformImportModule(ctx, ct)) - .SelectMany(static (m, _) => m); - - // IocImportModuleAttribute - var importModuleResultProvider_T1 = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocImportModuleAttributeFullName_T1, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformImportModuleGeneric(ctx, ct)) - .SelectMany(static (m, _) => m); - - // Combine all import module result providers - var allImportModuleResults = importModuleResultProvider - .Collect() - .Combine(importModuleResultProvider_T1.Collect()) - .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); - - // Pipeline 1: Extract DefaultSettingsModel from ImportModuleResult (for imported default settings) - var allImportedDefaultSettings = allImportModuleResults - .SelectMany(static (results, _) => results - .SelectMany(static r => r.DefaultSettings)) - .Collect(); - - // Pipeline 2: Extract OpenGenericEntries from ImportModuleResult (for cross-assembly open generic discovery) - var allImportedOpenGenerics = allImportModuleResults - .SelectMany(static (results, _) => results - .SelectMany(static r => r.OpenGenericEntries)) - .Collect(); - - // Get MSBuild properties from analyzer config options - var msbuildPropertiesProvider = context.AnalyzerConfigOptionsProvider - .Select(static (configOptions, _) => - { - // Try to get RootNamespace from MSBuild property - string? rootNamespace = null; - if(configOptions.GlobalOptions.TryGetValue(Constants.RootNamespaceProperty, out var ns) - && ns is { Length: > 0 } rawRootNamespace - && !string.IsNullOrWhiteSpace(rawRootNamespace)) - { - rootNamespace = rawRootNamespace; - } - - // Try to get custom IoC name from MSBuild property - string? customIocName = null; - if(configOptions.GlobalOptions.TryGetValue(Constants.SourceGenIocNameProperty, out var iocName) - && iocName is { Length: > 0 } rawCustomIocName - && !string.IsNullOrWhiteSpace(rawCustomIocName)) - { - customIocName = rawCustomIocName; - } - - // Try to get default lifetime from MSBuild property - ServiceLifetime? defaultLifetime = null; - if(configOptions.GlobalOptions.TryGetValue(Constants.SourceGenIocDefaultLifetimeProperty, out var lifetimeStr) - && lifetimeStr is { Length: > 0 } rawLifetime - && !string.IsNullOrWhiteSpace(rawLifetime)) - { - var trimmed = rawLifetime.Trim(); - defaultLifetime = trimmed switch - { - _ when trimmed.Equals("singleton", StringComparison.OrdinalIgnoreCase) => ServiceLifetime.Singleton, - _ when trimmed.Equals("scoped", StringComparison.OrdinalIgnoreCase) => ServiceLifetime.Scoped, - _ when trimmed.Equals("transient", StringComparison.OrdinalIgnoreCase) => ServiceLifetime.Transient, - _ => null - }; - } - - // Try to get enabled feature flags from MSBuild property - configOptions.GlobalOptions.TryGetValue(Constants.SourceGenIocFeaturesProperty, out var featuresStr); - var features = IocFeaturesHelper.Parse(featuresStr); - - return new MsBuildProperties(rootNamespace, customIocName, defaultLifetime, features); - }); - - // Combine default settings from current assembly and imported modules - // Current assembly settings take precedence over imported settings - var combinedDefaultSettings = allDefaultSettings - .Combine(allImportedDefaultSettings) - .Combine(msbuildPropertiesProvider) - .Select(static (combined, _) => - { - var ((currentAssembly, imported), msbuildProps) = combined; - // Current assembly settings come first (higher priority), then imported settings (lower priority) - // DefaultSettingsMap uses first-match semantics, so current assembly settings should be added first - var allSettings = currentAssembly.AddRange(imported); - return new DefaultSettingsMap(allSettings, msbuildProps.DefaultLifetime ?? ServiceLifetime.Transient); - }); - - // Collect GetService, GetRequiredService, GetKeyedService, GetRequiredKeyedService, GetServices invocations - var invocations = context.SyntaxProvider - .CreateSyntaxProvider( - predicate: static (node, _) => PredicateInvocations(node), - transform: static (ctx, ct) => TransformInvocations(ctx, ct)) - .SelectMany(static (candidates, _) => candidates) - .Collect(); - - // ========== IocDiscoverAttribute providers ========== - // IocDiscoverAttribute (non-generic) - var discoverProvider = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocDiscoverAttributeFullName, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformDiscover(ctx, ct)) - .SelectMany(static (m, _) => m); - - // IocDiscoverAttribute - var discoverProvider_T1 = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocDiscoverAttributeFullName_T1, - predicate: static (_, _) => true, - transform: static (ctx, ct) => TransformDiscoverGeneric(ctx, ct)) - .SelectMany(static (m, _) => m); - - // Combine all discover providers - var allDiscoverProviders = discoverProvider - .Collect() - .Combine(discoverProvider_T1.Collect()) - .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); - - // Get compilation info (assembly name and DI package reference) - var compilationInfoProvider = context.CompilationProvider - .Select(static (compilation, _) => - { - var assemblyName = compilation.AssemblyName ?? "Generated"; - // Detect if Microsoft.Extensions.DependencyInjection package is referenced - // by checking for ServiceCollectionContainerBuilderExtensions type - var hasDIPackage = compilation.GetTypeByMetadataName( - "Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions") is INamedTypeSymbol; - return (AssemblyName: assemblyName, HasDIPackage: hasDIPackage); - }); - - // ========== Pipeline 1: Process individual registrations (cacheable per registration) ========== - // Each registration is processed independently with default settings. - - var basicRegistrationResults1 = registerProvider - .Combine(combinedDefaultSettings) - .Select(static (source, _) => ProcessSingleRegistration(source.Left, source.Right)); - - var basicRegistrationResults1_T1 = registerProvider_T1 - .Combine(combinedDefaultSettings) - .Select(static (source, _) => ProcessSingleRegistration(source.Left, source.Right)); - - var basicRegistrationResults2 = registerForProvider - .Combine(combinedDefaultSettings) - .Select(static (source, _) => ProcessSingleRegistration(source.Left, source.Right)); - - var basicRegistrationResults2_T1 = registerForProvider_T1 - .Combine(combinedDefaultSettings) - .Select(static (source, _) => ProcessSingleRegistration(source.Left, source.Right)); - - // Process ImplementationTypes from IocRegisterDefaultsAttribute - // These registrations already have all settings applied from the defaults attribute - // Transform each RegistrationData to BasicRegistrationResult - var basicRegistrationResults3 = defaultSettingsImplTypeRegistrations - .Select(static (registrations, _) => ProcessSingleRegistrationFromDefaults(registrations)); - - // Collect all basic registration results - var allBasicResults = basicRegistrationResults1.Collect() - .Combine(basicRegistrationResults1_T1.Collect()) - .Combine(basicRegistrationResults2.Collect()) - .Combine(basicRegistrationResults2_T1.Collect()) - .Combine(basicRegistrationResults3.Collect()) - .Select(static (combined, _) => - { - var part1 = combined.Left.Left.Left.Left; - var part2 = combined.Left.Left.Left.Right; - var part3 = combined.Left.Left.Right; - var part4 = combined.Left.Right; - var part5 = combined.Right; - - var builder = ImmutableArray.CreateBuilder( - part1.Length + part2.Length + part3.Length + part4.Length + part5.Length); - - builder.AddRange(part1); - builder.AddRange(part2); - builder.AddRange(part3); - builder.AddRange(part4); - builder.AddRange(part5); - - return builder.MoveToImmutable(); - }); - - // ========== Pipeline 2: Combine results and resolve closed generics ========== - - // Combine invocations with discover attributes - var combinedClosedGenericDependencies = invocations - .Combine(allDiscoverProviders) - .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); - - // Combine factory-based open generics with imported open generics from other assemblies - var allOpenGenericEntries = factoryBasedOpenGenericEntries - .Combine(allImportedOpenGenerics) - .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); - - var serviceRegistrations = allBasicResults - .Combine(combinedClosedGenericDependencies) - .Combine(allOpenGenericEntries) - .Select(static (source, ct) => CombineAndResolveClosedGenerics(in source.Left.Left, in source.Left.Right, in source.Right, ct)); - - // Combine service registrations with compilation info and MSBuild properties - var combined = serviceRegistrations - .Combine(compilationInfoProvider) - .Combine(msbuildPropertiesProvider); - - // Generate output - context.RegisterSourceOutput(combined, static (ctx, source) => - { - var ((registrations, compilationInfo), msbuildProps) = source; - // Use RootNamespace from MSBuild if available, otherwise fall back to assembly name - var rootNamespace = msbuildProps.RootNamespace ?? compilationInfo.AssemblyName; - GenerateRegisterOutput(in ctx, registrations, rootNamespace, compilationInfo.AssemblyName, msbuildProps); - }); - - // ========== Container Pipeline ========== - // IocContainerAttribute provider - var containerProvider = context.SyntaxProvider - .ForAttributeWithMetadataName( - Constants.IocContainerAttributeFullName, - predicate: static (node, _) => node is ClassDeclarationSyntax, - transform: static (ctx, ct) => TransformContainer(ctx, ct)) - .Where(static m => m is not null) - .Select(static (m, _) => m!); - - // Combine container with existing serviceRegistrations and group them - var containerWithGroups = containerProvider - .Combine(serviceRegistrations) - .Select(static (source, _) => GroupRegistrationsForContainer(source.Left, source.Right)); - - // Combine with compilation info and MSBuild properties - var containerWithCompilationInfo = containerWithGroups - .Combine(compilationInfoProvider) - .Combine(msbuildPropertiesProvider); - - // Generate Container output (separate from Registration output) - context.RegisterSourceOutput(containerWithCompilationInfo, static (ctx, source) => - { - var ((containerWithGroups, compilationInfo), msbuildProps) = source; - GenerateContainerOutput(in ctx, containerWithGroups, compilationInfo.AssemblyName, msbuildProps, compilationInfo.HasDIPackage); - }); - } - -} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/SPEC.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/SPEC.spec.md deleted file mode 100644 index c832a3a..0000000 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/SPEC.spec.md +++ /dev/null @@ -1,408 +0,0 @@ -# IocSourceGenerator Specification - -Source generators for compile-time IoC container generation based on `Microsoft.Extensions.DependencyInjection.Abstractions`. This Overview page provides index and consolidated specification data. Feature documentation is split into focused spec files under `Register.*.md` and `Container.*.md`. - -## Spec Index - -Find detailed documentation for each feature: - -### Registration Features - -|Feature|File|Description| -|:---|:---|:---| -|Basic Registration|[Register.Basic.spec.md](Register.Basic.spec.md)|Core service registration patterns including implementation types and keyed services| -|Decorators|[Register.Decorators.spec.md](Register.Decorators.spec.md)|Decorator pattern for composing services with multiple layers| -|Tags|[Register.Tags.spec.md](Register.Tags.spec.md)|Tag-based mutually exclusive service registration| -|Injection Members|[Register.Injection.spec.md](Register.Injection.spec.md)|Field, property, method, async method, and constructor injection patterns| -|Imported Modules|[Register.ImportModule.spec.md](Register.ImportModule.spec.md)|Cross-assembly module importing and sharing registrations| -|Open Generics|[Register.Generics.spec.md](Register.Generics.spec.md)|Generic service types, closed generic discovery, and generic factory mapping| -|IServiceProvider|[Register.ServiceProviderInvocation.spec.md](Register.ServiceProviderInvocation.spec.md)|Automatic service discovery from IServiceProvider invocations| -|MSBuild Configuration|[Register.MSBuild.spec.md](Register.MSBuild.spec.md)|MSBuild property configuration for generator behavior| -|Factory & Instance|[Register.Factory.spec.md](Register.Factory.spec.md)|Factory method and static instance registration| -|KeyValuePair|[Register.KeyValuePair.spec.md](Register.KeyValuePair.spec.md)|KeyValuePair and Dictionary registrations for keyed service collections| - -### Container Features - -|Feature|File|Description| -|:---|:---|:---| -|Basic Container|[Container.Basic.spec.md](Container.Basic.spec.md)|Generated container overview and service resolution| -|Service Lifetime|[Container.Lifetime.spec.md](Container.Lifetime.spec.md)|Singleton, Scoped, and Transient lifecycle management| -|Keyed Services|[Container.KeyedServices.spec.md](Container.KeyedServices.spec.md)|Keyed service resolution with multiple key types| -|Injection|[Container.Injection.spec.md](Container.Injection.spec.md)|Constructor, property, field, and method injection in containers| -|Decorators|[Container.Decorators.spec.md](Container.Decorators.spec.md)|Decorator ordering and composition within containers| -|Imported Modules|[Container.ImportModule.spec.md](Container.ImportModule.spec.md)|FrozenDictionary-based service resolution with module composition| -|Factory & Instance|[Container.Factory.spec.md](Container.Factory.spec.md)|Factory-created and static instance service handling| -|Open Generics|[Container.Generics.spec.md](Container.Generics.spec.md)|Open generic service resolution| -|Collections & Wrappers|[Container.Collections.spec.md](Container.Collections.spec.md)|Collection types (IEnumerable, arrays) and wrapper types (Lazy, Func, Task, KeyValuePair)| -|Container Options|[Container.Options.spec.md](Container.Options.spec.md)|Configuration attributes and behavior flags (IntegrateServiceProvider, ExplicitOnly, etc.)| -|Thread Safety|[Container.ThreadSafety.spec.md](Container.ThreadSafety.spec.md)|Thread-safe service initialization strategies (Lock, SemaphoreSlim, SpinLock, CompareExchange)| -|Partial Accessors|[Container.PartialAccessors.spec.md](Container.PartialAccessors.spec.md)|Fast-path service resolution via partial members| -|MVC & Blazor|[Container.AspNetCore.spec.md](Container.AspNetCore.spec.md)|IControllerActivator, IComponentActivator, and IComponentPropertyActivator support| -|Performance|[Container.Performance.spec.md](Container.Performance.spec.md)|Disposal order, eager resolution, and code generation efficiency| - -## Collecting Information - -### 1. Registration Attributes - -|Attribute|Purpose|Generic Version| -|:--------|:------|:--------------| -|`IocRegisterAttribute`|Mark class for registration|`IocRegisterAttribute`| -|`IocRegisterForAttribute`|Register external types|`IocRegisterForAttribute`| -|`IocRegisterDefaultsAttribute`|Default settings for registrations|`IocRegisterDefaultsAttribute`| -|`IocImportModuleAttribute`|Import other assembly's settings|`IocImportModuleAttribute`| -|`IocDiscoverAttribute`|Explicit closed generic discovery|`IocDiscoverAttribute`| -|`IocGenericFactoryAttribute`|Generic factory type mapping|β€”| - -### 2. Registration Properties - -|Property|Source| -|:-------|:-----| -|Service Type|`TargetServiceType`, `ServiceTypes`, `RegisterAllInterfaces`, `RegisterAllBaseClasses`| -|Implementation Type|`IocRegisterForAttribute.ImplementationType`, marked class, defaults `ImplementationTypes`| -|Lifetime|Attribute β†’ defaults β†’ MSBuild `SourceGenIocDefaultLifetime` β†’ `Transient`| -|Key / KeyType|Attribute β†’ defaults| -|KeyValueType|Resolved `TypeData` of the key value (e.g., `string`, enum, `Guid`). `null` when `KeyType=Csharp` without `nameof()`| -|Decorators|`Decorators` property (with constructor params and type constraints)| -|Tags|Attribute β†’ defaults| -|Factory|`Factory` property (method path, supports generic mapping)| -|Instance|`Instance` property (static instance path, e.g., `"MyService.Default"`)| -|ValidOpenGenericServiceTypes|Set of valid open generic service type names for constraint checking| - -### 3. Type Hierarchy Collection - -|Data|Description| -|:---|:----------| -|`AllInterfaces`|All interfaces implemented by the type| -|`AllBaseClasses`|All base classes (excluding `System.Object`)| -|`TypeParameters`|Generic type parameters with constraints| -|`ConstructorParameters`|Constructor parameters (for decorators)| -|`WrapperKind`|`None`, `Enumerable`, `ReadOnlyCollection`, `Collection`, `ReadOnlyList`, `List`, `Array`, `Lazy`, `Func`, `Task`, `Dictionary`, or `KeyValuePair`| - -### 4. Injection Members - -|Member Type|Resolution| -|:----------|:---------| -|Property|With `[IocInject]`/`[Inject]`, set via object initializer| -|Field|With `[IocInject]`/`[Inject]`, set via object initializer| -|Method|With `[IocInject]`/`[Inject]`, called after construction| -|AsyncMethod|With `[IocInject]`/`[Inject]`, awaited after synchronous member injection when `AsyncMethodInject` is enabled| - -### 5. IServiceProvider Invocations - -Collect service types from invocations: `GetService`, `GetRequiredService`, `GetKeyedService`, `GetRequiredKeyedService`, `GetServices`, `GetKeyedServices` (and non-generic overloads) - -### 6. Compilation Info - -|Property|Source| -|:-------|:-----| -|Root Namespace|MSBuild `RootNamespace` (fallback: assembly name)| -|Assembly Name|Compilation options| -|Custom Method Name|`SourceGenIocName` MSBuild property| -|Default Lifetime|`SourceGenIocDefaultLifetime` MSBuild property (fallback: Transient)| -|Features|`SourceGenIocFeatures` MSBuild property (fallback: `Register,Container,PropertyInject,MethodInject`)| - -### 7. Feature Flags - -The `SourceGenIocFeatures` MSBuild property controls which outputs and injection member kinds are generated. - -Available features: - -|Feature|Value|Description| -|:------|:----|:----------| -|`Register`|`1 << 0`|Enable generation of the registration extension method output.| -|`Container`|`1 << 1`|Enable generation of the container class output.| -|`PropertyInject`|`1 << 2`|Enable property injection member generation.| -|`FieldInject`|`1 << 3`|Enable field injection member generation.| -|`MethodInject`|`1 << 4`|Enable synchronous method injection member generation.| -|`AsyncMethodInject`|`1 << 5`|Enable awaited `[IocInject]`/`[Inject]` methods that return non-generic `Task`. This feature MUST be combined with `MethodInject`; otherwise the analyzer MUST report `SGIOC026`.| - -Default value: - -`Register,Container,PropertyInject,MethodInject` - -`AsyncMethodInject` is **NOT** part of `Default`. - -Behavior: - -- `Register`: Controls whether the registration extension method output is generated. -- `Container`: Controls whether the container class output is generated. -- `PropertyInject` / `FieldInject` / `MethodInject`: Control which synchronous injection member types are included in generated code. -- `AsyncMethodInject`: Controls awaited async method injection for `[IocInject]` methods that return `Task`. - -Feature dependency rules: - -|Condition|Required behavior| -|:--------|:----------------| -|`AsyncMethodInject` enabled and `MethodInject` disabled|The configuration is invalid. The analyzer MUST report `SGIOC026`: `'AsyncMethodInject' feature requires 'MethodInject' to be enabled.`| -|`AsyncMethodInject` omitted|`Task`-returning injection methods are not enabled and MUST NOT participate in generated injection code.| - -Enabling example: - -```xml - - Register,Container,PropertyInject,MethodInject,AsyncMethodInject - -``` - -Parsing rules: - -- Comma-separated values. -- Case-insensitive matching. -- Whitespace is trimmed around each value. -- Invalid values are ignored. - -## Parse Logic - -### 1. Key Interpretation - -|KeyType|Behavior|Example| -|:------|:-------|:------| -|`Value`|Use literal value|`42`, `"myString"`, `MyEnum.Value`| -|`Csharp`|Evaluate as C# expression|`MyClass.StaticField`, `nameof(...)`| - -### 2. Default Settings Priority - -When multiple defaults match an implementation type: - -1. Directly on implementation type -2. On closest base class -3. On first interface in `AllInterfaces` - -### 2.1 Service Type Determination for `ImplementationTypes` - -When `IocRegisterDefaults` provides `ImplementationTypes`, the generator derives service types per implementation with the following rules: - -|Condition|Behavior| -|:--------|:-------| -|Implementation type is open generic|MUST use `TargetServiceType` and append configured `ServiceTypes` (if any).| -|Implementation type is closed generic or non-generic and matching closed types are found from `AllInterfaces`/`AllBaseClasses`|MUST use those matched closed types as service types.| -|Implementation type is closed generic or non-generic and no closed type matches `TargetServiceType` (for example, framework metadata is not visible during generation)|MUST fall back to `TargetServiceType` directly instead of leaving `ServiceTypes` empty.| - -This fallback is required for scenarios such as Razor components where `IComponent` might not be visible to the source generator from the implementation type hierarchy. - -```csharp -// Valid: fallback to TargetServiceType when the hierarchy scan cannot resolve IComponent. -[assembly: IocRegisterDefaults( - typeof(Microsoft.AspNetCore.Components.IComponent), - ServiceLifetime.Scoped, - ImplementationTypes = [typeof(MyAppComponent)])] - -public partial class MyAppComponent : Microsoft.AspNetCore.Components.ComponentBase -{ -} -``` - -```csharp -// Invalid outcome (must not happen): ServiceTypes becomes empty for MyAppComponent. -// Required behavior is to include TargetServiceType as fallback. -``` - -### 3. Settings Merge Order - -`Explicit attribute` β†’ `Matching defaults` β†’ `MSBuild SourceGenIocDefaultLifetime` β†’ `Transient` - -### 4. Inject Attribute Matching - -Match by name only: `IocInjectAttribute` or `InjectAttribute` -(Supports third-party attributes like `Microsoft.AspNetCore.Components.InjectAttribute`) - -### 5. Constructor Selection - -|Priority|Condition| -|-------:|:--------| -|1|Marked with `[IocInject]`| -|2|Primary constructor| -|3|Constructor with most parameters| - -### 6. Parameter Resolution - -|Condition|Action| -|:--------|:-----| -|`[ServiceKey]` attribute|Inject registration key| -|`[FromKeyedServices]` or `[IocInject(Key=...)]`|Keyed service resolution| -|`IServiceProvider` type|Pass provider directly| -|Collection types (`IEnumerable`, `T[]`, etc.)|Extract `T` as service type| - -### 7. Property/Field Injection - -Only members with `[IocInject]` or `[Inject]`: - -|Condition|Behavior| -|:--------|:-------| -|With `Key`|Keyed service resolution| -|`IServiceProvider`|Pass provider directly| -|Collection types|Extract inner type as service| -|Nullable type|Assign resolved nullable value| -|Has default value|Use resolved if non-null| - -### 8. Wrapper Kind Resolution - -`WrapperKind` is a unified enum. Each value has a dedicated `TypeData` derived type. - -|`WrapperKind`|TypeData Type|Types|Resolution| -|:------------|:------------|:----|:---------| -|`Enumerable`|`EnumerableTypeData`|`IEnumerable`|MS.E.DI native collection support| -|`ReadOnlyCollection`|`ReadOnlyCollectionTypeData`|`IReadOnlyCollection`|`GetServices().ToArray()`| -|`Collection`|`CollectionTypeData`|`ICollection`|`GetServices().ToArray()`| -|`ReadOnlyList`|`ReadOnlyListTypeData`|`IReadOnlyList`|`GetServices().ToArray()`| -|`List`|`ListTypeData`|`IList`|`GetServices().ToArray()`| -|`Array`|`ArrayTypeData`|`T[]`|`GetServices().ToArray()`| -|`Lazy`|`LazyTypeData`|`Lazy`|Lazy-initialized service wrapper| -|`Func`|`FuncTypeData`|`Func` / `Func`|Factory delegate wrapper| -|`Task`|`TaskTypeData`|`Task`|Async-init wrapper; resolve `Task` directly for async-init services or wrap sync resolution with `Task.FromResult(...)` for sync-only services.| -|`Dictionary`|`DictionaryTypeData`|`IDictionary`|Dictionary of keyed services| -|`KeyValuePair`|`KeyValuePairTypeData`|`KeyValuePair`|Single keyed service entry| - -#### Type Hierarchy - -```tree -TypeData -└── GenericTypeData - β”œβ”€β”€ TypeParameterTypeData - └── WrapperTypeData (WrapperKind) - β”œβ”€β”€ CollectionWrapperTypeData - β”‚ β”œβ”€β”€ EnumerableTypeData (Enumerable) - β”‚ β”œβ”€β”€ ReadOnlyCollectionTypeData (ReadOnlyCollection) - β”‚ β”œβ”€β”€ CollectionTypeData (Collection) - β”‚ β”œβ”€β”€ ReadOnlyListTypeData (ReadOnlyList) - β”‚ β”œβ”€β”€ ListTypeData (List) - β”‚ └── ArrayTypeData (Array) - β”œβ”€β”€ LazyTypeData (Lazy) - β”œβ”€β”€ FuncTypeData (Func) - β”œβ”€β”€ TaskTypeData (Task) - β”œβ”€β”€ DictionaryTypeData (Dictionary) - └── KeyValuePairTypeData (KeyValuePair) -``` - -Wrapper types support nesting. For example, `IEnumerable>` is parsed as: - -- `EnumerableTypeData` (`WrapperKind.Enumerable`) - - `TypeParameters[0].Type` = `LazyTypeData` (`WrapperKind.Lazy`) - - `TypeParameters[0].Type` = `TypeData` (`IMyService`) - -### 9. Generic Factory Type Mapping - -`IocGenericFactoryAttribute` maps service type parameters to factory method type parameters: - -```csharp -// Single type parameter: IRequestHandler<> -[IocRegisterDefaults(typeof(IRequestHandler<>), Factory = nameof(Create))] -public class FactoryContainer -{ - // typeof(int) is placeholder, maps to T - [IocGenericFactory(typeof(IRequestHandler>), typeof(int))] - public static IRequestHandler Create() => new Handler(); -} - -// Multiple type parameters: IRequestHandler<,> -[IocRegisterDefaults(typeof(IRequestHandler<,>), Factory = nameof(Create))] -public class FactoryContainer -{ - // decimal β†’ T1, int β†’ T2 - [IocGenericFactory(typeof(IRequestHandler, decimal>), typeof(decimal), typeof(int))] - public static IRequestHandler Create() => new Handler(); -} -``` - -## Generators - -1. Registration generator: generate `IServiceCollection` register code.\ -[Registration features spec](#registration-features) - -2. Container generator: generate container that implement `IServiceProvider`.\ -[Container features spec](#container-features) - -## Implementation Requirements - -### Source Generator Architecture - -Implemented at `IocSourceGenerator`, using the Incremental Generator pattern. -The generated code requires .NET 10.0 or later. - -```filetree -src/SourceGen.Ioc.SourceGenerator/ -β”œβ”€β”€ Generator/ -β”‚ β”œβ”€β”€ IocSourceGenerator.cs # Main generator (partial) with Initialize() -β”‚ β”œβ”€β”€ Transform*.cs # Attribute β†’ model transforms (Register, DefaultSettings, ImportModule, Discover, Container) -β”‚ β”œβ”€β”€ ProcessSingleRegistration.cs # Apply defaults to individual registrations -β”‚ β”œβ”€β”€ CombineAndResolveClosedGenerics.cs # Combine results & resolve closed generics from open generics -β”‚ β”œβ”€β”€ IServiceProviderInvocations.cs # Collect IServiceProvider invocations -β”‚ β”œβ”€β”€ GroupRegistrationsForContainer.cs # Group registrations for container generation -β”‚ β”œβ”€β”€ Generate*Output.cs # Code emitters (Register, Container) -β”‚ β”œβ”€β”€ LazyRegistrationHelper.cs # Lazy wrapper registration helper -β”‚ β”œβ”€β”€ FuncRegistrationHelper.cs # Func wrapper registration helper -β”‚ β”œβ”€β”€ KvpRegistrationHelper.cs # KeyValuePair registration helper -β”‚ └── Spec/ # SPEC.spec.md + Register.*.md + Container.*.md -β”œβ”€β”€ Models/ # Immutable data models (RegistrationData, TypeData, etc.) -└── Analyzer/ # Diagnostic analyzers & SPEC.spec.md -``` - -### Data Flow - -```mermaid -flowchart TB - subgraph IocSourceGenerator.Initialize - subgraph Attribute Providers - IocRegister["[IocRegister]"] - IocRegisterFor["[IocRegisterFor]"] - IocRegisterDefaults["[IocRegisterDefaults]"] - IocImportModule["[IocImportModule]"] - IocDiscover["[IocDiscover]"] - Invocations["IServiceProvider Invocations"] - IocContainer["[IocContainer]"] - end - - subgraph Default Settings - allDefaultSettings["allDefaultSettings"] - allImportedDefaultSettings["allImportedDefaultSettings"] - combinedDefaultSettings["combinedDefaultSettings
(DefaultSettingsMap)"] - end - - subgraph Registration Pipeline - allBasicResults["allBasicResults
ImmutableEquatableArray#lt;ServiceRegistrationWithTags#gt;"] - combinedClosedGenericDependencies["combinedClosedGenericDependencies
(Invocations + Discover)"] - allOpenGenericEntries["allOpenGenericEntries
(Factory-based + Imported)"] - CombineResolve["CombineAndResolveClosedGenerics"] - serviceRegistrations["serviceRegistrations
ImmutableEquatableArray#lt;ServiceRegistrationWithTags#gt;"] - end - - subgraph Container Pipeline - ContainerModel["ContainerModel"] - CombineGroup["Combine & Group
(GroupRegistrationsForContainer)"] - ContainerWithGroups["ContainerWithGroups"] - end - - subgraph Output Generation - GenerateRegisterOutput["GenerateRegisterOutput
(.ServiceRegistration.g.cs)"] - GenerateContainerOutput["GenerateContainerOutput
(.Container.g.cs)"] - end - - IocRegisterDefaults --> allDefaultSettings - IocRegisterDefaults -->|OpenGenericEntries| allOpenGenericEntries - IocImportModule --> allImportedDefaultSettings - IocImportModule -->|OpenGenericEntries| allOpenGenericEntries - allDefaultSettings --> combinedDefaultSettings - allImportedDefaultSettings --> combinedDefaultSettings - - IocRegister --> allBasicResults - IocRegisterFor --> allBasicResults - IocRegisterDefaults -->|ImplementationTypes| allBasicResults - combinedDefaultSettings -.->|applied to| allBasicResults - - IocDiscover --> combinedClosedGenericDependencies - Invocations --> combinedClosedGenericDependencies - - allBasicResults --> CombineResolve - combinedClosedGenericDependencies --> CombineResolve - allOpenGenericEntries --> CombineResolve - CombineResolve --> serviceRegistrations - - IocContainer --> ContainerModel - ContainerModel --> CombineGroup - serviceRegistrations --> CombineGroup - CombineGroup --> ContainerWithGroups - - serviceRegistrations --> GenerateRegisterOutput - ContainerWithGroups --> GenerateContainerOutput - end -``` diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/GlobalUsings.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/GlobalUsings.cs index 52eed52..b8ad6d2 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/GlobalUsings.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/GlobalUsings.cs @@ -10,4 +10,5 @@ global using PolyType.Roslyn; global using SourceGen.Ioc.SourceGenerator; global using SourceGen.Ioc.SourceGenerator.Models; -global using static SourceGen.Ioc.SourceGenerator.RoslynExtensions; \ No newline at end of file +global using static SourceGen.Ioc.SourceGenerator.RoslynExtensions; +global using static SourceGen.Ioc.SourceGenerator.Roslyn.TypeParameterSubstitution; \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/ContainerGroupingModels.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/ContainerGroupingModels.cs new file mode 100644 index 0000000..6d7254e --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/ContainerGroupingModels.cs @@ -0,0 +1,29 @@ +using ContainerEntryModel = SourceGen.Ioc.IocSourceGenerator.ContainerEntry; + +namespace SourceGen.Ioc.SourceGenerator.Models; + +/// +/// Helper record to group registrations for container generation. +/// Uses immutable collections for proper incremental generator caching. +/// +/// Last registration wins lookup by service type and key using container entry models. +/// All unique service type names for IsService checks. +/// Whether there are any open generic registrations. +/// Whether there are any keyed service registrations. +/// Service types with multiple implementations for IEnumerable resolution. +/// Parallel service entry model for singleton registrations. +/// Parallel service entry model for scoped registrations. +/// Parallel service entry model for transient registrations. +/// Parallel wrapper entry model for Lazy/Func/KVP registrations. +/// Parallel collection entry model for IEnumerable and related wrappers. +internal sealed record ContainerRegistrationGroups( + ImmutableEquatableDictionary<(string ServiceType, string? Key), ContainerEntryModel> LastWinsByServiceType, + ImmutableEquatableSet AllServiceTypes, + bool HasOpenGenerics, + bool HasKeyedServices, + ImmutableEquatableArray CollectionServiceTypes, + ImmutableEquatableArray SingletonEntries, + ImmutableEquatableArray ScopedEntries, + ImmutableEquatableArray TransientEntries, + ImmutableEquatableArray WrapperEntries, + ImmutableEquatableArray CollectionEntries); \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/GroupRegistrationsForContainer.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/GroupRegistrationsForContainer.cs new file mode 100644 index 0000000..ef903de --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/GroupRegistrationsForContainer.cs @@ -0,0 +1,1154 @@ +ο»Ώusing static SourceGen.Ioc.SourceGenerator.Models.Constants; + +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Transforms container and registrations into grouped data for code generation. + /// This step is separated from output generation to enable incremental generator caching. + /// + private static ContainerWithGroups GroupRegistrationsForContainer( + ContainerModel container, + ImmutableEquatableArray registrations, + IocFeatures features) + { + var reservedNames = GetReservedNames(container); + + // Group registrations for code generation + var groups = BuildContainerRegistrationGroups(registrations, features, container.ThreadSafeStrategy, container.EagerResolveOptions, reservedNames); + + return new ContainerWithGroups(container, groups); + } + + /// + /// Filters registrations based on the container's IncludeTags settings. + /// + private static ImmutableEquatableArray FilterRegistrationsForContainer( + ContainerModel container, + ImmutableEquatableArray allRegistrations) + { + // Apply IncludeTags filtering if specified + if(container.IncludeTags.Length > 0) + { + // Only include services that have at least one matching tag + return allRegistrations + .Where(r => r.Tags.Length > 0 && r.Tags.Any(tag => container.IncludeTags.Contains(tag, StringComparer.Ordinal))) + .Select(static r => r.Registration) + .ToImmutableEquatableArray(); + } + + // Include all registrations from the assembly + return allRegistrations + .Select(static r => r.Registration) + .ToImmutableEquatableArray(); + } + + /// + /// Groups explicit registrations for ExplicitOnly containers. + /// + private static ContainerWithGroups GroupExplicitOnlyRegistrations( + ContainerModel container, + IocFeatures features) + { + var registrations = new List(container.ExplicitRegistrations.Length); + + foreach(var explicitReg in container.ExplicitRegistrations) + { + var processed = ProcessExplicitRegistrationForContainer(explicitReg); + if(processed is not null) + { + registrations.Add(processed); + } + } + + var reservedNames = GetReservedNames(container); + + var groups = BuildContainerRegistrationGroups( + registrations.ToImmutableEquatableArray(), + features, + container.ThreadSafeStrategy, + container.EagerResolveOptions, + reservedNames); + + return new ContainerWithGroups(container, groups); + } + + /// + /// Collect partial accessor method names to avoid naming conflicts + /// + private static HashSet GetReservedNames(ContainerModel container) + { + var reservedNames = new HashSet(StringComparer.Ordinal); + foreach(var accessor in container.PartialAccessors) + { + if(accessor.Kind == PartialAccessorKind.Method) + { + reservedNames.Add(accessor.Name); + } + } + return reservedNames; + } + + /// + /// Converts explicit RegistrationData to ServiceRegistrationModel for container generation. + /// + private static ServiceRegistrationModel? ProcessExplicitRegistrationForContainer(RegistrationData data) + { + // Get the first service type, or use implementation type + var serviceType = data.ServiceTypes.Length > 0 + ? data.ServiceTypes[0] + : data.ImplementationType; + + return new ServiceRegistrationModel( + serviceType, + data.ImplementationType, + data.Lifetime, + data.Key, + data.KeyType, + data.KeyValueType, + data.ImplementationType is GenericTypeData { IsOpenGeneric: true }, + data.Decorators, + data.InjectionMembers, + data.Factory, + data.Instance); + } + + /// + /// Groups registrations by service type and key for efficient lookup and collection resolution. + /// Pre-computes field names, method names, and disposal lists to avoid redundant calculations. + /// + private static ContainerRegistrationGroups BuildContainerRegistrationGroups( + ImmutableEquatableArray registrations, + IocFeatures features, + ThreadSafeStrategy threadSafeStrategy, + EagerResolveOptions eagerResolveOptions, + HashSet reservedNames) + { + // Group by (ServiceType.Name, Key) for efficient lookup + var serviceLookup = new Dictionary<(string ServiceType, string? Key), List>(); + var lastWinsLookup = new Dictionary<(string ServiceType, string? Key), ServiceLookupEntry>(); + + // Track all unique service types for IsService checks + var allServiceTypes = new HashSet(); + + // Track unique implementations per lifetime using Dictionary instead of List + index tracking. + // Key: (ImplementationName, ServiceKey, InstanceOrFactory), Value: (ServiceLookupEntry, HasDecorators) + var singletonMap = new Dictionary<(string ImplName, string? Key, string? InstanceOrFactory), (ServiceLookupEntry Cached, bool HasDecorators)>(); + var scopedMap = new Dictionary<(string ImplName, string? Key, string? InstanceOrFactory), (ServiceLookupEntry Cached, bool HasDecorators)>(); + var transientMap = new Dictionary<(string ImplName, string? Key, string? InstanceOrFactory), (ServiceLookupEntry Cached, bool HasDecorators)>(); + + var hasAllInjectionFeatures = IocFeaturesHelper.HasAllInjectionFeatures(features); + var hasOpenGenerics = false; + var hasKeyedServices = false; + + foreach(var reg in registrations) + { + var effectiveRegistration = hasAllInjectionFeatures ? reg : FilterRegistrationForFeatures(reg, features); + + if(effectiveRegistration.IsOpenGeneric) + { + hasOpenGenerics = true; + // Skip open generics for most processing but track the flag + continue; + } + + // Pre-compute field and method names once, including IsEager flag + var cached = CreateServiceLookupEntry(effectiveRegistration, eagerResolveOptions, reservedNames); + + var key = (effectiveRegistration.ServiceType.Name, effectiveRegistration.Key); + if(effectiveRegistration.Key is not null) + { + hasKeyedServices = true; + } + + if(!serviceLookup.TryGetValue(key, out var list)) + { + list = []; + serviceLookup[key] = list; + } + list.Add(cached); + lastWinsLookup[key] = cached; + + // Also add implementation type as a service type (for self-registration) + if(effectiveRegistration.ImplementationType.Name != effectiveRegistration.ServiceType.Name) + { + var implKey = (effectiveRegistration.ImplementationType.Name, effectiveRegistration.Key); + if(!serviceLookup.TryGetValue(implKey, out var implList)) + { + implList = []; + serviceLookup[implKey] = implList; + } + implList.Add(cached); + lastWinsLookup[implKey] = cached; + } + + // Track closed types for IsService checks + allServiceTypes.Add(effectiveRegistration.ServiceType.Name); + allServiceTypes.Add(effectiveRegistration.ImplementationType.Name); + + // Group by lifetime - prefer registration with decorators for field generation + // Include Instance or Factory in the key to distinguish multiple instance/factory registrations + var instanceOrFactory = effectiveRegistration.Instance ?? effectiveRegistration.Factory?.Path; + var lifetimeKey = (effectiveRegistration.ImplementationType.Name, effectiveRegistration.Key, instanceOrFactory); + var hasDecorators = effectiveRegistration.Decorators.Length > 0; + + var targetMap = effectiveRegistration.Lifetime switch + { + ServiceLifetime.Singleton => singletonMap, + ServiceLifetime.Scoped => scopedMap, + _ => transientMap + }; + + if(targetMap.TryGetValue(lifetimeKey, out var existing)) + { + // Skip if already seen with decorators, or if current doesn't have decorators + if(existing.HasDecorators || !hasDecorators) + { + continue; + } + } + + targetMap[lifetimeKey] = (cached, hasDecorators); + } + + // Convert maps to lists (preserving insertion order via Dictionary enumeration order in .NET) + var singletons = singletonMap.Values.Select(static v => v.Cached).ToImmutableEquatableArray(); + var scoped = scopedMap.Values.Select(static v => v.Cached).ToImmutableEquatableArray(); + var transients = transientMap.Values.Select(static v => v.Cached).ToImmutableEquatableArray(); + + // Collect service types with multiple registrations for IEnumerable resolution + // Async-init services are excluded from collection resolution (only Task can access them). + var collectionServiceTypes = new List(); + var collectionRegistrations = new Dictionary>(); + + foreach(var kvp in serviceLookup) + { + // Include non-keyed service types with multiple registrations + if(kvp.Key.Key is null && kvp.Value.Count > 1) + { + // Filter out async-init registrations β€” they cannot appear in IEnumerable resolvers + var effectiveRegistrations = kvp.Value.Where(static c => !c.IsAsyncInit).ToImmutableEquatableArray(); + + // Deduplicate resolver method names to count unique implementations + var uniqueResolvers = new HashSet(); + foreach(var cached in effectiveRegistrations) + { + uniqueResolvers.Add(cached.ResolverMethodName); + } + + // Only generate collection if there are multiple unique resolvers + if(uniqueResolvers.Count > 1) + { + collectionServiceTypes.Add(kvp.Key.ServiceType); + collectionRegistrations[kvp.Key.ServiceType] = effectiveRegistrations; + } + } + } + + var serviceLookupEntries = serviceLookup.ToDictionary( + static kvp => kvp.Key, + static kvp => kvp.Value.ToImmutableEquatableArray()); + + var wrapperEntries = CreateWrapperContainerEntries( + singletons, + scoped, + transients, + serviceLookupEntries); + + var lazyFieldByResolver = wrapperEntries + .OfType() + .ToDictionary(static e => e.InnerResolverMethodName, static e => e.FieldName, StringComparer.Ordinal); + var funcFieldByResolver = wrapperEntries + .OfType() + .ToDictionary(static e => e.InnerResolverMethodName, static e => e.FieldName, StringComparer.Ordinal); + + var singletonEntries = CreateServiceContainerEntries( + singletons, + threadSafeStrategy, + serviceLookupEntries, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver); + var scopedEntries = CreateServiceContainerEntries( + scoped, + threadSafeStrategy, + serviceLookupEntries, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver); + var transientEntries = CreateServiceContainerEntries( + transients, + threadSafeStrategy, + serviceLookupEntries, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver); + var collectionEntries = CreateCollectionContainerEntries(collectionRegistrations); + var immutableLastWinsByServiceType = CreateLastWinsByServiceType( + lastWinsLookup, + singletonEntries, + scopedEntries, + transientEntries); + + return new ContainerRegistrationGroups( + immutableLastWinsByServiceType, + allServiceTypes.ToImmutableEquatableSet(), + hasOpenGenerics, + hasKeyedServices, + collectionServiceTypes.ToImmutableEquatableArray(), + singletonEntries, + scopedEntries, + transientEntries, + wrapperEntries, + collectionEntries); + } + + private static ImmutableEquatableDictionary<(string ServiceType, string? Key), IocSourceGenerator.ContainerEntry> CreateLastWinsByServiceType( + Dictionary<(string ServiceType, string? Key), ServiceLookupEntry> lastWinsLookup, + ImmutableEquatableArray singletonEntries, + ImmutableEquatableArray scopedEntries, + ImmutableEquatableArray transientEntries) + { + var entryByResolverMethodName = new Dictionary(StringComparer.Ordinal); + var entryByResolverAndServiceType = new Dictionary<(string ResolverMethodName, string ServiceType, string? Key), IocSourceGenerator.ContainerEntry>(); + + foreach(var entry in singletonEntries) + AddEntriesByLookup(entryByResolverMethodName, entryByResolverAndServiceType, entry); + foreach(var entry in scopedEntries) + AddEntriesByLookup(entryByResolverMethodName, entryByResolverAndServiceType, entry); + foreach(var entry in transientEntries) + AddEntriesByLookup(entryByResolverMethodName, entryByResolverAndServiceType, entry); + + var lastWinsByServiceType = new Dictionary<(string ServiceType, string? Key), IocSourceGenerator.ContainerEntry>(lastWinsLookup.Count); + + foreach(var kvp in lastWinsLookup) + { + var cached = kvp.Value; + + if(!entryByResolverAndServiceType.TryGetValue((cached.ResolverMethodName, cached.Registration.ServiceType.Name, cached.Registration.Key), out var entry) + && !entryByResolverMethodName.TryGetValue(cached.ResolverMethodName, out entry)) + { + continue; + } + + if(entry is ServiceContainerEntry serviceEntry + && (!string.Equals(serviceEntry.Registration.ServiceType.Name, cached.Registration.ServiceType.Name, StringComparison.Ordinal) + || !string.Equals(serviceEntry.Registration.Key, cached.Registration.Key, StringComparison.Ordinal))) + { + entry = CloneServiceEntryWithRegistration(entry, cached.Registration); + } + + lastWinsByServiceType[kvp.Key] = entry; + } + + return lastWinsByServiceType.ToImmutableEquatableDictionary(); + } + + private static void AddEntriesByLookup( + Dictionary entryByResolverMethodName, + Dictionary<(string ResolverMethodName, string ServiceType, string? Key), IocSourceGenerator.ContainerEntry> entryByResolverAndServiceType, + IocSourceGenerator.ContainerEntry entry) + { + AddEntryByResolverMethodName(entryByResolverMethodName, entry); + + if(entry is not ServiceContainerEntry serviceEntry) + { + return; + } + + var resolverKey = ( + serviceEntry.ResolverMethodName, + serviceEntry.Registration.ServiceType.Name, + serviceEntry.Registration.Key); + + if(!entryByResolverAndServiceType.ContainsKey(resolverKey)) + { + entryByResolverAndServiceType[resolverKey] = entry; + } + } + + private static IocSourceGenerator.ContainerEntry CloneServiceEntryWithRegistration( + IocSourceGenerator.ContainerEntry entry, + ServiceRegistrationModel registration) + { + return entry switch + { + InstanceContainerEntry instance => instance with { Registration = registration }, + EagerContainerEntry eager => eager with { Registration = registration }, + LazyThreadSafeContainerEntry lazy => lazy with { Registration = registration }, + TransientContainerEntry transient => transient with { Registration = registration }, + AsyncContainerEntry asyncEntry => asyncEntry with { Registration = registration }, + AsyncTransientContainerEntry asyncTransient => asyncTransient with { Registration = registration }, + _ => entry + }; + } + + /// + /// Creates a ServiceLookupEntry with pre-computed field and method names. + /// Computes both names in a single pass to avoid redundant string operations. + /// + private static ServiceLookupEntry CreateServiceLookupEntry( + ServiceRegistrationModel reg, + EagerResolveOptions eagerResolveOptions, + HashSet reservedNames) + { + var (fieldName, methodName) = ComputeServiceNames(reg); + + // Avoid naming conflicts with user-declared partial accessor methods + if(reservedNames.Contains(methodName)) + { + methodName = $"{methodName}_Resolve"; + } + + // Determine if this registration should be eagerly resolved via EagerResolveOptions. + // Instance registrations are inherently eager (no field caching needed). + // Transient services are not supported for eager resolution. + var isAsyncInit = HasAsyncInitMembers(reg); + var isEager = reg.Instance is null && reg.Lifetime switch + { + ServiceLifetime.Singleton => (eagerResolveOptions & EagerResolveOptions.Singleton) != 0, + ServiceLifetime.Scoped => (eagerResolveOptions & EagerResolveOptions.Scoped) != 0, + _ => false // Transient is never eager + }; + + return new ServiceLookupEntry(reg, methodName, fieldName, isAsyncInit, isEager); + } + + /// + /// Returns when the registration has at least one + /// member, making it an async-init service. + /// + private static bool HasAsyncInitMembers(ServiceRegistrationModel reg) + { + foreach(var m in reg.InjectionMembers) + { + if(m.MemberType == InjectionMemberType.AsyncMethod) + return true; + } + return false; + } + + /// + /// Computes both field name and resolver method name for a service in a single pass. + /// This avoids redundant GetSafeIdentifier calls and string operations. + /// + /// A tuple containing (FieldName, ResolverMethodName). + private static (string FieldName, string ResolverMethodName) ComputeServiceNames(ServiceRegistrationModel reg) + { + var implType = reg.ImplementationType; + var typeName = implType switch + { + GenericTypeData { IsOpenGeneric: false } genericTypeData when genericTypeData.Name != genericTypeData.NameWithoutGeneric => genericTypeData.Name, + GenericTypeData genericTypeData => genericTypeData.NameWithoutGeneric, + _ => implType.Name, + }; + var baseName = GetSafeIdentifier(typeName); + var lowerFirstChar = char.ToLowerInvariant(baseName[0]); + var restOfName = baseName[1..]; + + // Handle keyed services + if(reg.Key is not null) + { + var safeKey = GetSafeIdentifier(reg.Key); + return ( + $"_{lowerFirstChar}{restOfName}_{safeKey}", + $"Get{baseName}_{safeKey}" + ); + } + + // Handle instance registrations - include instance name in the method name + if(reg.Instance is not null) + { + var safeInstance = GetSafeIdentifier(reg.Instance); + return ( + $"_{lowerFirstChar}{restOfName}_{safeInstance}", + $"Get{baseName}_{safeInstance}" + ); + } + + // Handle factory registrations - include factory path in the method name + if(reg.Factory is not null) + { + var safeFactory = GetSafeIdentifier(reg.Factory.Path); + return ( + $"_{lowerFirstChar}{restOfName}_{safeFactory}", + $"Get{baseName}_{safeFactory}" + ); + } + + return ( + $"_{lowerFirstChar}{restOfName}", + $"Get{baseName}" + ); + } + + private static ImmutableEquatableArray CreateServiceContainerEntries( + ImmutableEquatableArray registrations, + ThreadSafeStrategy threadSafeStrategy, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup, + IReadOnlyDictionary> collectionRegistrations, + IReadOnlyDictionary lazyFieldByResolver, + IReadOnlyDictionary funcFieldByResolver) + { + var entries = new List(registrations.Length); + + foreach(var cached in registrations) + { + var reg = cached.Registration; + var constructorParameters = ResolveConstructorParametersForContainerEntryModel( + reg, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver, + allowServiceKeyAttribute: true); + var injectionMembers = ResolveInjectionMembersForContainerEntryModel( + reg.InjectionMembers, + reg, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver, + allowServiceKeyAttributeForMethods: true); + var decorators = ResolveDecoratorsForContainerEntryModel( + reg, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver); + + entries.Add(CreateServiceContainerEntryModel( + cached, + threadSafeStrategy, + constructorParameters, + injectionMembers, + decorators)); + } + + return entries.ToImmutableEquatableArray(); + } + + private static IocSourceGenerator.ContainerEntry CreateServiceContainerEntryModel( + ServiceLookupEntry cached, + ThreadSafeStrategy threadSafeStrategy, + ImmutableEquatableArray constructorParameters, + ImmutableEquatableArray injectionMembers, + ImmutableEquatableArray decorators) + { + var reg = cached.Registration; + + if(reg.Instance is not null) + { + return new InstanceContainerEntry( + reg, + cached.ResolverMethodName, + constructorParameters, + injectionMembers, + decorators); + } + + if(cached.IsAsyncInit) + { + if(reg.Lifetime is ServiceLifetime.Singleton or ServiceLifetime.Scoped) + { + return new AsyncContainerEntry( + reg, + cached.ResolverMethodName, + cached.FieldName!, + cached.IsEager, + GetEffectiveThreadSafeStrategy(threadSafeStrategy, true), + constructorParameters, + injectionMembers, + decorators); + } + + return new AsyncTransientContainerEntry( + reg, + cached.ResolverMethodName, + constructorParameters, + injectionMembers, + decorators); + } + + if(reg.Lifetime is ServiceLifetime.Singleton or ServiceLifetime.Scoped) + { + if(cached.IsEager) + { + return new EagerContainerEntry( + reg, + cached.ResolverMethodName, + cached.FieldName!, + constructorParameters, + injectionMembers, + decorators); + } + + return new LazyThreadSafeContainerEntry( + reg, + cached.ResolverMethodName, + cached.FieldName!, + threadSafeStrategy, + constructorParameters, + injectionMembers, + decorators); + } + + return new TransientContainerEntry( + reg, + cached.ResolverMethodName, + constructorParameters, + injectionMembers, + decorators); + } + + private static ImmutableEquatableArray CreateWrapperContainerEntries( + ImmutableEquatableArray singletons, + ImmutableEquatableArray scoped, + ImmutableEquatableArray transients, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup) + { + var wrapperEntries = new List(); + + wrapperEntries.AddRange(CreateLazyWrapperContainerEntries(singletons, scoped, transients, serviceLookup)); + wrapperEntries.AddRange(CreateFuncWrapperContainerEntries(singletons, scoped, transients, serviceLookup)); + wrapperEntries.AddRange(CreateKvpWrapperContainerEntries(singletons, scoped, transients, serviceLookup)); + + return wrapperEntries.ToImmutableEquatableArray(); + } + + private static ImmutableEquatableArray CreateCollectionContainerEntries( + IReadOnlyDictionary> collectionRegistrations) + { + var entries = new List(collectionRegistrations.Count); + + foreach(var kvp in collectionRegistrations) + { + var elementResolvers = new List(kvp.Value.Length); + var uniqueKeys = new HashSet(StringComparer.Ordinal); + + foreach(var cached in kvp.Value) + { + var uniqueKey = cached.Registration.Instance ?? cached.ResolverMethodName; + if(!uniqueKeys.Add(uniqueKey)) + continue; + + if(cached.Registration.Instance is not null) + { + elementResolvers.Add(new InstanceExpressionDependency(cached.Registration.Instance)); + } + else + { + elementResolvers.Add(new DirectServiceDependency(cached.ResolverMethodName)); + } + } + + if(elementResolvers.Count < 2) + continue; + + entries.Add(new CollectionContainerEntry( + kvp.Key, + GetArrayResolverMethodName(kvp.Key), + elementResolvers.ToImmutableEquatableArray())); + } + + return entries.ToImmutableEquatableArray(); + } + + private static ImmutableEquatableArray ResolveConstructorParametersForContainerEntryModel( + ServiceRegistrationModel registration, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup, + IReadOnlyDictionary> collectionRegistrations, + IReadOnlyDictionary lazyFieldByResolver, + IReadOnlyDictionary funcFieldByResolver, + bool allowServiceKeyAttribute) + { + var constructorParameters = registration.Factory?.AdditionalParameters ?? registration.ImplementationType.ConstructorParameters; + if(constructorParameters is null or { Length: 0 }) + return []; + + var resolved = new List(constructorParameters.Length); + + foreach(var parameter in constructorParameters) + { + var dependency = ResolveParameterDependencyForContainerEntryModel( + parameter, + registration, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver, + allowServiceKeyAttribute); + + resolved.Add(new ResolvedConstructorParameter(parameter, dependency, parameter.IsOptional)); + } + + return resolved.ToImmutableEquatableArray(); + } + + private static ImmutableEquatableArray ResolveInjectionMembersForContainerEntryModel( + ImmutableEquatableArray injectionMembers, + ServiceRegistrationModel ownerRegistration, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup, + IReadOnlyDictionary> collectionRegistrations, + IReadOnlyDictionary lazyFieldByResolver, + IReadOnlyDictionary funcFieldByResolver, + bool allowServiceKeyAttributeForMethods) + { + if(injectionMembers.Length == 0) + return []; + + var resolvedMembers = new List(injectionMembers.Length); + + foreach(var member in injectionMembers) + { + ResolvedDependency? dependency = null; + ImmutableEquatableArray parameterDependencies = []; + + switch(member.MemberType) + { + case InjectionMemberType.Property or InjectionMemberType.Field when member.Type is not null: + dependency = ResolveServiceDependencyForContainerEntryModel( + member.Type, + member.Key, + member.IsNullable, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver); + break; + + case InjectionMemberType.Method or InjectionMemberType.AsyncMethod: + { + var methodParameters = member.Parameters; + if(methodParameters is { Length: > 0 }) + { + var resolvedParameters = new List(methodParameters.Length); + + foreach(var parameter in methodParameters) + { + resolvedParameters.Add(ResolveParameterDependencyForContainerEntryModel( + parameter, + ownerRegistration, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver, + allowServiceKeyAttributeForMethods)); + } + + parameterDependencies = resolvedParameters.ToImmutableEquatableArray(); + if(parameterDependencies.Length == 1) + { + dependency = parameterDependencies[0]; + } + } + + break; + } + } + + resolvedMembers.Add(new ResolvedInjectionMember(member, dependency, parameterDependencies)); + } + + return resolvedMembers.ToImmutableEquatableArray(); + } + + private static ImmutableEquatableArray ResolveDecoratorsForContainerEntryModel( + ServiceRegistrationModel registration, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup, + IReadOnlyDictionary> collectionRegistrations, + IReadOnlyDictionary lazyFieldByResolver, + IReadOnlyDictionary funcFieldByResolver) + { + if(registration.Decorators.Length == 0) + return []; + + var decorators = new List(registration.Decorators.Length); + + foreach(var decoratorType in registration.Decorators) + { + var constructorParameters = decoratorType.ConstructorParameters; + var resolvedParameters = new List(constructorParameters?.Length ?? 0); + + if(constructorParameters is { Length: > 1 }) + { + for(var i = 1; i < constructorParameters.Length; i++) + { + var parameter = constructorParameters[i]; + var dependency = ResolveServiceDependencyForContainerEntryModel( + parameter.Type, + parameter.ServiceKey, + parameter.IsOptional, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver); + + resolvedParameters.Add(new ResolvedConstructorParameter(parameter, dependency, parameter.IsOptional)); + } + } + + var decoratorInjectionMembers = decoratorType.InjectionMembers ?? []; + var resolvedInjectionMembers = ResolveInjectionMembersForContainerEntryModel( + decoratorInjectionMembers, + registration, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver, + allowServiceKeyAttributeForMethods: false); + + var decoratorRegistration = new ServiceRegistrationModel( + registration.ServiceType, + decoratorType, + registration.Lifetime, + registration.Key, + registration.KeyType, + registration.KeyValueType, + decoratorType is GenericTypeData { IsOpenGeneric: true }, + [], + decoratorInjectionMembers, + Factory: null, + Instance: null); + + decorators.Add(new ResolvedDecorator( + decoratorRegistration, + resolvedParameters.ToImmutableEquatableArray(), + resolvedInjectionMembers)); + } + + return decorators.ToImmutableEquatableArray(); + } + + private static ResolvedDependency ResolveParameterDependencyForContainerEntryModel( + ParameterData parameter, + ServiceRegistrationModel ownerRegistration, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup, + IReadOnlyDictionary> collectionRegistrations, + IReadOnlyDictionary lazyFieldByResolver, + IReadOnlyDictionary funcFieldByResolver, + bool allowServiceKeyAttribute) + { + if(allowServiceKeyAttribute && parameter.HasServiceKeyAttribute) + { + return new ServiceKeyLiteralDependency(parameter.Type.Name, ownerRegistration.Key ?? "null"); + } + + if(parameter.Type.Name is IServiceProviderTypeName or IServiceProviderGlobalTypeName) + { + return new ServiceProviderSelfDependency(); + } + + return ResolveServiceDependencyForContainerEntryModel( + parameter.Type, + parameter.ServiceKey, + parameter.IsOptional, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver); + } + + private static ResolvedDependency ResolveServiceDependencyForContainerEntryModel( + TypeData type, + string? key, + bool isOptional, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup, + IReadOnlyDictionary> collectionRegistrations, + IReadOnlyDictionary lazyFieldByResolver, + IReadOnlyDictionary funcFieldByResolver) + { + if(type is CollectionWrapperTypeData collectionType) + { + var elementTypeName = collectionType.ElementType.Name; + + if(key is not null) + { + return new CollectionFallbackDependency(elementTypeName, IsKeyed: true, Key: key); + } + + if(collectionType.ElementType is KeyValuePairTypeData kvpElement + && HasKvpRegistrationsForContainerEntryModel(kvpElement.KeyType.Name, kvpElement.ValueType.Name, serviceLookup)) + { + var isArrayType = collectionType.WrapperKind is WrapperKind.ReadOnlyList or WrapperKind.List or WrapperKind.Array; + return isArrayType + ? new KvpResolverDependency(GetKvpArrayResolverMethodName(kvpElement.KeyType.Name, kvpElement.ValueType.Name)) + : new DictionaryResolverDependency(GetKvpDictionaryResolverMethodName(kvpElement.KeyType.Name, kvpElement.ValueType.Name)); + } + + if(collectionRegistrations.ContainsKey(elementTypeName)) + { + return new CollectionDependency(GetArrayResolverMethodName(elementTypeName)); + } + + return new CollectionFallbackDependency(elementTypeName, IsKeyed: false, Key: null); + } + + if(type is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData or TaskTypeData) + { + return ResolveWrapperDependencyForContainerEntryModel( + type, + key, + isOptional, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver, + useResolverMethods: true); + } + + if(serviceLookup.TryGetValue((type.Name, key), out var registrations)) + { + var cached = registrations[^1]; + if(cached.IsAsyncInit) + { + return cached.Registration.Lifetime == ServiceLifetime.Transient + ? new DirectServiceDependency(GetAsyncCreateMethodName(cached.ResolverMethodName)) + : new DirectServiceDependency(GetAsyncResolverMethodName(cached.ResolverMethodName)); + } + + return new DirectServiceDependency(cached.ResolverMethodName); + } + + return new FallbackProviderDependency(type.Name, key, isOptional); + } + + private static ResolvedDependency ResolveWrapperDependencyForContainerEntryModel( + TypeData type, + string? key, + bool isOptional, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup, + IReadOnlyDictionary> collectionRegistrations, + IReadOnlyDictionary lazyFieldByResolver, + IReadOnlyDictionary funcFieldByResolver, + bool useResolverMethods) + { + switch(type) + { + case LazyTypeData lazy: + { + var innerType = lazy.InstanceType; + + if(innerType is not WrapperTypeData && useResolverMethods) + { + if(serviceLookup.TryGetValue((innerType.Name, key), out var innerRegistrations)) + { + var resolverMethodName = innerRegistrations[^1].ResolverMethodName; + if(lazyFieldByResolver.TryGetValue(resolverMethodName, out var fieldName)) + { + return new LazyFieldReferenceDependency(fieldName); + } + } + + return new LazyInlineDependency( + innerType.Name, + new FallbackProviderDependency(innerType.Name, key, isOptional)); + } + + return new LazyInlineDependency( + innerType.Name, + ResolveInnerDependencyForContainerEntryModel( + innerType, + key, + isOptional, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver)); + } + + case FuncTypeData func: + { + var innerType = func.ReturnType; + + if(func.HasInputParameters) + { + if(serviceLookup.TryGetValue((innerType.Name, key), out var innerRegistrations)) + { + var targetRegistration = innerRegistrations[^1].Registration; + return new MultiParamFuncDependency( + innerType.Name, + CreateFuncInputParameters(func.InputTypes), + ResolveConstructorParametersForContainerEntryModel( + targetRegistration, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver, + allowServiceKeyAttribute: true), + ResolveInjectionMembersForContainerEntryModel( + targetRegistration.InjectionMembers, + targetRegistration, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver, + allowServiceKeyAttributeForMethods: true), + ResolveDecoratorsForContainerEntryModel( + targetRegistration, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver), + targetRegistration.ImplementationType.Name); + } + + return new FallbackProviderDependency(type.Name, key, isOptional); + } + + if(innerType is not WrapperTypeData && useResolverMethods) + { + if(serviceLookup.TryGetValue((innerType.Name, key), out var innerRegistrations)) + { + var resolverMethodName = innerRegistrations[^1].ResolverMethodName; + if(funcFieldByResolver.TryGetValue(resolverMethodName, out var fieldName)) + { + return new FuncFieldReferenceDependency(fieldName); + } + } + + return new FuncInlineDependency( + innerType.Name, + new FallbackProviderDependency(innerType.Name, key, isOptional)); + } + + return new FuncInlineDependency( + innerType.Name, + ResolveInnerDependencyForContainerEntryModel( + innerType, + key, + isOptional, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver)); + } + + case KeyValuePairTypeData kvp: + return new KvpInlineDependency( + kvp.KeyType.Name, + kvp.ValueType.Name, + key ?? "default", + ResolveInnerDependencyForContainerEntryModel( + kvp.ValueType, + key, + isOptional, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver)); + + case DictionaryTypeData dictionary: + { + if(key is null && HasKvpRegistrationsForContainerEntryModel(dictionary.KeyType.Name, dictionary.ValueType.Name, serviceLookup)) + { + return new DictionaryResolverDependency(GetKvpDictionaryResolverMethodName(dictionary.KeyType.Name, dictionary.ValueType.Name)); + } + + var kvpTypeName = $"global::System.Collections.Generic.KeyValuePair<{dictionary.KeyType.Name}, {dictionary.ValueType.Name}>"; + return new DictionaryFallbackDependency(kvpTypeName, IsKeyed: key is not null, Key: key); + } + + case TaskTypeData task: + { + if(serviceLookup.TryGetValue((task.InnerType.Name, key), out var innerRegistrations)) + { + var cached = innerRegistrations[^1]; + if(cached.IsAsyncInit) + { + return new TaskAsyncDependency(GetAsyncResolverMethodName(cached.ResolverMethodName), task.InnerType.Name); + } + + return new TaskFromResultDependency(new DirectServiceDependency(cached.ResolverMethodName), task.InnerType.Name); + } + + return new FallbackProviderDependency(type.Name, key, isOptional); + } + + default: + return ResolveServiceDependencyForContainerEntryModel( + type, + key, + isOptional, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver); + } + } + + private static ResolvedDependency ResolveInnerDependencyForContainerEntryModel( + TypeData innerType, + string? key, + bool isOptional, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup, + IReadOnlyDictionary> collectionRegistrations, + IReadOnlyDictionary lazyFieldByResolver, + IReadOnlyDictionary funcFieldByResolver) + { + if(innerType is LazyTypeData or FuncTypeData or KeyValuePairTypeData or DictionaryTypeData or TaskTypeData) + { + return ResolveWrapperDependencyForContainerEntryModel( + innerType, + key, + isOptional, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver, + useResolverMethods: false); + } + + return ResolveServiceDependencyForContainerEntryModel( + innerType, + key, + isOptional, + serviceLookup, + collectionRegistrations, + lazyFieldByResolver, + funcFieldByResolver); + } + + private static ImmutableEquatableArray CreateFuncInputParameters(ImmutableEquatableArray inputTypes) + { + if(inputTypes.Length == 0) + return []; + + var parameters = new List(inputTypes.Length); + for(var i = 0; i < inputTypes.Length; i++) + { + parameters.Add(new ParameterData($"arg{i}", inputTypes[i].Type)); + } + + return parameters.ToImmutableEquatableArray(); + } + + private static bool HasKvpRegistrationsForContainerEntryModel( + string keyTypeName, + string valueTypeName, + IReadOnlyDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> serviceLookup) + { + foreach(var kvp in serviceLookup) + { + if(kvp.Key.Key is null) + continue; + + if(!string.Equals(kvp.Key.ServiceType, valueTypeName, StringComparison.Ordinal)) + continue; + + var cached = kvp.Value[^1]; + if(IsKeyTypeCompatible(keyTypeName, cached.Registration.KeyValueType)) + return true; + } + + return false; + } + + private readonly record struct ServiceLookupEntry( + ServiceRegistrationModel Registration, + string ResolverMethodName, + string? FieldName, + bool IsAsyncInit, + bool IsEager); +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/GroupRegistrationsForRegister.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/GroupRegistrationsForRegister.cs new file mode 100644 index 0000000..3f6cab3 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/GroupRegistrationsForRegister.cs @@ -0,0 +1,251 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + private static RegisterOutputModel? GroupRegistrationsForRegister( + ImmutableEquatableArray registrations, + string rootNamespace, + string assemblyName, + string? customIocName, + IocFeatures features) + { + if((features & IocFeatures.Register) == 0 || registrations.Length == 0) + return null; + + var methodBaseName = !string.IsNullOrWhiteSpace(customIocName) + ? GetSafeIdentifier(customIocName!) + : GetSafeIdentifier(assemblyName); + + var canonicalTagsCache = new Dictionary, ImmutableEquatableArray>(); + foreach(var regWithTags in registrations) + { + var tags = regWithTags.Tags; + if(canonicalTagsCache.ContainsKey(tags)) + { + continue; + } + + if(tags.Length > 0) + { + var sortedTags = tags.OrderBy(static t => t, StringComparer.Ordinal).ToImmutableEquatableArray(); + canonicalTagsCache[tags] = sortedTags; + } + else + { + canonicalTagsCache[tags] = []; + } + } + + var shouldFilterInjection = !IocFeaturesHelper.HasAllInjectionFeatures(features); + var groupedRegistrations = new Dictionary, List>(); + + foreach(var regWithTags in registrations) + { + var registration = shouldFilterInjection + ? FilterRegistrationForFeatures(regWithTags.Registration, features) + : regWithTags.Registration; + var tagKey = canonicalTagsCache[regWithTags.Tags]; + + if(!groupedRegistrations.TryGetValue(tagKey, out var group)) + { + group = []; + groupedRegistrations[tagKey] = group; + } + + var entry = CreateRegisterEntry(registration); + if(entry is null) + { + continue; + } + + group.Add(entry); + } + + var lazyByTagKey = GroupLazyEntriesByTagKey(CollectLazyEntries(registrations), canonicalTagsCache); + var funcByTagKey = GroupFuncEntriesByTagKey(CollectFuncEntries(registrations), canonicalTagsCache); + var kvpByTagKey = GroupKvpEntriesByTagKey(CollectKeyValuePairEntries(registrations), canonicalTagsCache); + + var asyncInitServiceTypeSet = new HashSet(StringComparer.Ordinal); + foreach(var group in groupedRegistrations.Values) + { + foreach(var entry in group) + { + if(!entry.Registration.InjectionMembers.Any(static m => m.MemberType == InjectionMemberType.AsyncMethod)) + continue; + + asyncInitServiceTypeSet.Add(entry.Registration.ServiceType.Name); + asyncInitServiceTypeSet.Add(entry.Registration.ImplementationType.Name); + } + } + + var asyncInitServiceTypes = asyncInitServiceTypeSet.Count > 0 + ? asyncInitServiceTypeSet.ToImmutableEquatableSet() + : null; + + var tagGroups = groupedRegistrations + .OrderBy(static kvp => kvp.Key, TagArrayComparer.Instance) + .Select(kvp => + { + var tags = kvp.Key; + + lazyByTagKey.TryGetValue(tags, out var lazyEntries); + funcByTagKey.TryGetValue(tags, out var funcEntries); + kvpByTagKey.TryGetValue(tags, out var kvpEntries); + + return new RegisterTagGroup( + tags, + kvp.Value.ToImmutableEquatableArray(), + lazyEntries is null ? [] : lazyEntries.ToImmutableEquatableArray(), + funcEntries is null ? [] : funcEntries.ToImmutableEquatableArray(), + kvpEntries is null ? [] : kvpEntries.ToImmutableEquatableArray()); + }) + .ToImmutableEquatableArray(); + + return new RegisterOutputModel( + methodBaseName, + rootNamespace, + assemblyName, + tagGroups, + asyncInitServiceTypes); + } + + private static Dictionary, List> GroupLazyEntriesByTagKey( + List entries, + Dictionary, ImmutableEquatableArray> canonicalTagsCache) + { + var grouped = new Dictionary, List>(); + foreach(var entry in entries) + { + var tagKey = canonicalTagsCache[entry.Tags]; + if(!grouped.TryGetValue(tagKey, out var list)) + { + list = []; + grouped[tagKey] = list; + } + + list.Add(entry); + } + + return grouped; + } + + private static Dictionary, List> GroupFuncEntriesByTagKey( + List entries, + Dictionary, ImmutableEquatableArray> canonicalTagsCache) + { + var grouped = new Dictionary, List>(); + foreach(var entry in entries) + { + var tagKey = canonicalTagsCache[entry.Tags]; + if(!grouped.TryGetValue(tagKey, out var list)) + { + list = []; + grouped[tagKey] = list; + } + + list.Add(entry); + } + + return grouped; + } + + private static Dictionary, List> GroupKvpEntriesByTagKey( + List entries, + Dictionary, ImmutableEquatableArray> canonicalTagsCache) + { + var grouped = new Dictionary, List>(); + foreach(var entry in entries) + { + var tagKey = canonicalTagsCache[entry.Tags]; + if(!grouped.TryGetValue(tagKey, out var list)) + { + list = []; + grouped[tagKey] = list; + } + + list.Add(entry); + } + + return grouped; + } + + private static RegisterEntry? CreateRegisterEntry(ServiceRegistrationModel registration) + { + var serviceTypeName = registration.ServiceType.Name; + var implTypeName = registration.ImplementationType.Name; + + bool hasFactory = registration.Factory is not null && !registration.IsOpenGeneric; + bool hasInstance = registration.Instance is not null && !registration.IsOpenGeneric; + bool isServiceTypeRegistration = serviceTypeName != implTypeName; + bool hasClosedDecorators = registration.Decorators.Length > 0 && isServiceTypeRegistration && !registration.IsOpenGeneric; + + bool hasInjectionMembers = registration.InjectionMembers.Length > 0; + bool hasInjectConstructor = registration.ImplementationType.HasInjectConstructor; + var constructorParams = registration.ImplementationType.ConstructorParameters; + bool hasSpecialConstructorParams = constructorParams?.Any(static p => + p.HasInjectAttribute || + p.Type.NeedsWrapperResolution || + (p.IsNullable && p.Type is LazyTypeData { InstanceType: not WrapperTypeData } or FuncTypeData { ReturnType: not WrapperTypeData }) || + p.HasDefaultValue) == true; + bool needsFactoryConstruction = hasInjectionMembers || hasInjectConstructor || hasSpecialConstructorParams; + bool hasAsyncInjectionMembers = registration.InjectionMembers.Any(static m => m.MemberType == InjectionMemberType.AsyncMethod); + bool shouldForwardServiceType = !registration.IsOpenGeneric && isServiceTypeRegistration; + + if(hasFactory) + { + return new FactoryRegisterEntry(registration); + } + + if(hasInstance) + { + if(registration.Lifetime == ServiceLifetime.Singleton) + { + return new InstanceRegisterEntry(registration); + } + + return null; + } + + if(hasClosedDecorators) + { + return new DecoratorRegisterEntry(registration); + } + + if(shouldForwardServiceType) + { + return new ForwardingRegisterEntry(registration); + } + + if(needsFactoryConstruction && !registration.IsOpenGeneric) + { + if(hasAsyncInjectionMembers) + { + return new AsyncInjectionRegisterEntry(registration); + } + else + { + return new InjectionRegisterEntry(registration); + } + } + + return new SimpleRegisterEntry(registration); + } + + private sealed class TagArrayComparer : IComparer> + { + public static readonly TagArrayComparer Instance = new(); + + public int Compare(ImmutableEquatableArray a, ImmutableEquatableArray b) + { + var minLength = Math.Min(a.Length, b.Length); + for(var i = 0; i < minLength; i++) + { + var cmp = StringComparer.Ordinal.Compare(a[i], b[i]); + if(cmp != 0) + return cmp; + } + + return a.Length.CompareTo(b.Length); + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/RegisterGroupingModels.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/RegisterGroupingModels.cs new file mode 100644 index 0000000..dbcba1a --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Grouping/RegisterGroupingModels.cs @@ -0,0 +1,14 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Tag-grouped register output data including registrations and wrapper entries. + /// + private sealed record class RegisterTagGroup( + ImmutableEquatableArray Tags, + ImmutableEquatableArray Registrations, + ImmutableEquatableArray LazyEntries, + ImmutableEquatableArray FuncEntries, + ImmutableEquatableArray KvpEntries); +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/IocSourceGenerator.ConfigProviders.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/IocSourceGenerator.ConfigProviders.cs new file mode 100644 index 0000000..f2ab3f9 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/IocSourceGenerator.ConfigProviders.cs @@ -0,0 +1,76 @@ +ο»Ώnamespace SourceGen.Ioc; + +// Helpers that read MSBuild / AnalyzerConfig / Compilation inputs into pipeline-friendly providers. +// Extracted from Initialize() to keep the orchestrator focused on pipeline wiring. +partial class IocSourceGenerator +{ + /// + /// Reads the SourceGenIocDefaultLifetime MSBuild property and parses it into a . + /// Returns null when the property is missing or unrecognised. + /// + private static IncrementalValueProvider BuildDefaultLifetimeProvider( + IncrementalGeneratorInitializationContext context) + => context.AnalyzerConfigOptionsProvider + .Select(static (configOptions, _) => + { + if(!configOptions.GlobalOptions.TryGetValue(Constants.SourceGenIocDefaultLifetimeProperty, out var lifetimeStr) + || string.IsNullOrWhiteSpace(lifetimeStr)) + { + return (ServiceLifetime?)null; + } + + var trimmed = lifetimeStr.Trim(); + return (ServiceLifetime?)(trimmed switch + { + _ when trimmed.Equals("singleton", StringComparison.OrdinalIgnoreCase) => ServiceLifetime.Singleton, + _ when trimmed.Equals("scoped", StringComparison.OrdinalIgnoreCase) => ServiceLifetime.Scoped, + _ when trimmed.Equals("transient", StringComparison.OrdinalIgnoreCase) => ServiceLifetime.Transient, + _ => null, + }); + }); + + /// + /// Reads MSBuild properties (RootNamespace, SourceGenIocName, SourceGenIocFeatures) into a record. + /// + private static IncrementalValueProvider BuildMsBuildPropertiesProvider( + IncrementalGeneratorInitializationContext context) + => context.AnalyzerConfigOptionsProvider + .Select(static (configOptions, _) => + { + var rootNamespace = TryReadNonEmptyProperty(configOptions, Constants.RootNamespaceProperty); + var customIocName = TryReadNonEmptyProperty(configOptions, Constants.SourceGenIocNameProperty); + + configOptions.GlobalOptions.TryGetValue(Constants.SourceGenIocFeaturesProperty, out var featuresStr); + var features = IocFeaturesHelper.Parse(featuresStr); + + return new MsBuildProperties(rootNamespace, customIocName, features); + }); + + private static string? TryReadNonEmptyProperty( + Microsoft.CodeAnalysis.Diagnostics.AnalyzerConfigOptionsProvider configOptions, + string propertyName) + { + if(configOptions.GlobalOptions.TryGetValue(propertyName, out var value) + && !string.IsNullOrWhiteSpace(value)) + { + return value; + } + return null; + } + + /// + /// Reads compilation-level inputs: assembly name (defaults to "Generated") and whether + /// Microsoft.Extensions.DependencyInjection is referenced (detected via ServiceCollectionContainerBuilderExtensions). + /// + private static IncrementalValueProvider<(string AssemblyName, bool HasDIPackage)> BuildCompilationInfoProvider( + IncrementalGeneratorInitializationContext context) + => context.CompilationProvider + .Select(static (compilation, _) => + { + var assemblyName = compilation.AssemblyName ?? "Generated"; + var hasDIPackage = compilation.GetTypeByMetadataName( + "Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions") is not null; + return (AssemblyName: assemblyName, HasDIPackage: hasDIPackage); + }); + +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/IocSourceGenerator.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/IocSourceGenerator.cs new file mode 100644 index 0000000..18455fe --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/IocSourceGenerator.cs @@ -0,0 +1,302 @@ +ο»Ώnamespace SourceGen.Ioc; + +/// +/// Incremental source generator that processes IocRegister*, IocContainer, IocImportModule, +/// IocDiscover, and IocRegisterDefaults attributes (plus IServiceProvider.GetService<T> +/// invocations) and emits two kinds of output: +/// +/// Register* partial methods β€” extension methods that register services into IServiceCollection. +/// Container partial classes β€” standalone DI containers that resolve services without IServiceCollection. +/// +/// +/// The pipeline mirrors the stages described in Spec/SPEC.spec.md: +/// +/// +/// Stage 1 β€” Attribute detection (Transforms/): symbol β†’ data model. +/// Stage 2 β€” Combine MSBuild / compilation / default-settings inputs. +/// Stage 3 β€” Per-registration processing (Processing/ProcessSingleRegistration, cacheable). +/// Stage 4 β€” Closed-generic resolution + grouping (Processing/ + Grouping/). +/// Stage 5 β€” Emit Register output and Container output (Emit/Register/, Emit/Container/). +/// +/// +[Generator(LanguageNames.CSharp)] +public sealed partial class IocSourceGenerator : IIncrementalGenerator +{ + public void Initialize(IncrementalGeneratorInitializationContext context) + { + // ===== Stage 1: Attribute providers (symbol -> data model) ===== + + // [IocRegister] / [IocRegister] + var registerProvider = context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocRegisterAttributeFullName, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformRegister(ctx, ct)) + .Where(static m => m is not null) + .Select(static (m, _) => m!); + + var registerProvider_T1 = context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocRegisterAttributeFullName_T1, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformRegisterGeneric(ctx, ct)) + .Where(static m => m is not null) + .Select(static (m, _) => m!); + + // [IocRegisterFor] / [IocRegisterFor] + var registerForProvider = context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocRegisterForAttributeFullName, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformRegisterFor(ctx, ct)) + .SelectMany(static (m, _) => m); + + var registerForProvider_T1 = context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocRegisterForAttributeFullName_T1, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformRegisterForGeneric(ctx, ct)) + .SelectMany(static (m, _) => m); + + // [IocRegisterDefaults] / [IocRegisterDefaults] -> (settings, impl-type registrations, factory open generics) + var allDefaultSettingsResults = context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocRegisterDefaultsAttributeFullName, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformDefaultSettings(ctx, ct)) + .SelectMany(static (m, _) => m) + .Collect() + .Combine(context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocRegisterDefaultsAttributeFullName_T1, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformDefaultSettingsGeneric(ctx, ct)) + .SelectMany(static (m, _) => m) + .Collect()) + .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); + + var allDefaultSettings = allDefaultSettingsResults + .SelectMany(static (results, _) => results + .Where(static r => r.DefaultSettings is not null) + .Select(static r => r.DefaultSettings!)) + .Collect(); + + var defaultSettingsImplTypeRegistrations = allDefaultSettingsResults + .SelectMany(static (results, _) => results.SelectMany(static r => r.ImplementationTypeRegistrations)); + + var factoryBasedOpenGenericEntries = allDefaultSettingsResults + .SelectMany(static (results, _) => results.SelectMany(static r => r.OpenGenericEntries)) + .Collect(); + + // [IocImportModule] / [IocImportModule] -> (imported settings, imported open generics) + var allImportModuleResults = context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocImportModuleAttributeFullName, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformImportModule(ctx, ct)) + .SelectMany(static (m, _) => m) + .Collect() + .Combine(context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocImportModuleAttributeFullName_T1, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformImportModuleGeneric(ctx, ct)) + .SelectMany(static (m, _) => m) + .Collect()) + .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); + + var allImportedDefaultSettings = allImportModuleResults + .SelectMany(static (results, _) => results.SelectMany(static r => r.DefaultSettings)) + .Collect(); + + var allImportedOpenGenerics = allImportModuleResults + .SelectMany(static (results, _) => results.SelectMany(static r => r.OpenGenericEntries)) + .Collect(); + + // [IocDiscover] / [IocDiscover] -> closed-generic dependencies discovered at compile time + var allDiscoverProviders = context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocDiscoverAttributeFullName, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformDiscover(ctx, ct)) + .SelectMany(static (m, _) => m) + .Collect() + .Combine(context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocDiscoverAttributeFullName_T1, + predicate: static (_, _) => true, + transform: static (ctx, ct) => TransformDiscoverGeneric(ctx, ct)) + .SelectMany(static (m, _) => m) + .Collect()) + .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); + + // IServiceProvider.GetService / GetRequiredService / GetKeyedService / GetServices invocations + var invocations = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => PredicateInvocations(node), + transform: static (ctx, ct) => TransformInvocations(ctx, ct)) + .SelectMany(static (candidates, _) => candidates) + .Collect(); + + // [IocContainer] + var containerProvider = context.SyntaxProvider + .ForAttributeWithMetadataName( + Constants.IocContainerAttributeFullName, + predicate: static (node, _) => node is ClassDeclarationSyntax, + transform: static (ctx, ct) => TransformContainer(ctx, ct)) + .Where(static m => m is not null) + .Select(static (m, _) => m!); + + // ===== Stage 2: Compilation / MSBuild / default-settings inputs ===== + + var defaultLifetimeProvider = BuildDefaultLifetimeProvider(context); + var msbuildPropertiesProvider = BuildMsBuildPropertiesProvider(context); + var compilationInfoProvider = BuildCompilationInfoProvider(context); + + // Current-assembly settings take precedence over imported settings (DefaultSettingsMap uses first-match semantics). + var combinedDefaultSettings = allDefaultSettings + .Combine(allImportedDefaultSettings) + .Combine(defaultLifetimeProvider) + .Select(static (combined, _) => + { + var ((currentAssembly, imported), defaultLifetime) = combined; + var allSettings = currentAssembly.AddRange(imported); + return new DefaultSettingsMap(allSettings, defaultLifetime ?? ServiceLifetime.Transient); + }); + + // ===== Stage 3: Per-registration processing (cacheable per registration) ===== + + var basicRegistrationResults1 = registerProvider + .Combine(combinedDefaultSettings) + .Select(static (s, ct) => ProcessSingleRegistration(s.Left, s.Right, ct)); + + var basicRegistrationResults1_T1 = registerProvider_T1 + .Combine(combinedDefaultSettings) + .Select(static (s, ct) => ProcessSingleRegistration(s.Left, s.Right, ct)); + + var basicRegistrationResults2 = registerForProvider + .Combine(combinedDefaultSettings) + .Select(static (s, ct) => ProcessSingleRegistration(s.Left, s.Right, ct)); + + var basicRegistrationResults2_T1 = registerForProvider_T1 + .Combine(combinedDefaultSettings) + .Select(static (s, ct) => ProcessSingleRegistration(s.Left, s.Right, ct)); + + // ImplementationTypes from [IocRegisterDefaults] already have all settings applied; just convert. + var basicRegistrationResults3 = defaultSettingsImplTypeRegistrations + .Select(static (registrations, ct) => ProcessSingleRegistrationFromDefaults(registrations, ct)); + + var allBasicResults = CollectAndConcat( + basicRegistrationResults1, + basicRegistrationResults1_T1, + basicRegistrationResults2, + basicRegistrationResults2_T1, + basicRegistrationResults3); + + // ===== Stage 4: Closed-generic resolution + grouping ===== + + var combinedClosedGenericDependencies = invocations + .Combine(allDiscoverProviders) + .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); + + var allOpenGenericEntries = factoryBasedOpenGenericEntries + .Combine(allImportedOpenGenerics) + .Select(static (combined, _) => combined.Left.AddRange(combined.Right)); + + var serviceRegistrations = allBasicResults + .Combine(combinedClosedGenericDependencies) + .Combine(allOpenGenericEntries) + .Select(static (s, ct) => CombineAndResolveClosedGenerics(in s.Left.Left, in s.Left.Right, in s.Right, ct)); + + // ===== Stage 5a: Emit Register output ===== + + var registerOutputModel = serviceRegistrations + .Combine(compilationInfoProvider) + .Combine(msbuildPropertiesProvider) + .Select(static (source, _) => + { + var ((registrations, compilationInfo), msbuildProps) = source; + var rootNamespace = msbuildProps.RootNamespace ?? compilationInfo.AssemblyName; + return GroupRegistrationsForRegister( + registrations, rootNamespace, compilationInfo.AssemblyName, + msbuildProps.CustomIocName, msbuildProps.Features); + }); + + context.RegisterSourceOutput(registerOutputModel, static (ctx, model) => + { + if(model is null) + return; + GenerateRegisterOutput(in ctx, model); + }); + + // ===== Stage 5b: Emit Container output ===== + // ExplicitOnly containers do not depend on serviceRegistrations (independent caching branch). + + var explicitOnlyContainerWithGroups = containerProvider + .Where(static c => c.ExplicitOnly) + .Combine(msbuildPropertiesProvider) + .Select(static (source, _) => + { + var (container, msbuildProps) = source; + return GroupExplicitOnlyRegistrations(container, msbuildProps.Features); + }); + + var normalContainerWithGroups = containerProvider + .Where(static c => !c.ExplicitOnly) + .Combine(serviceRegistrations) + .Select(static (source, _) => + { + var (container, registrations) = source; + var filtered = FilterRegistrationsForContainer(container, registrations); + return (container, filtered); + }) + .Combine(msbuildPropertiesProvider) + .Select(static (source, _) => + { + var ((container, filtered), msbuildProps) = source; + return GroupRegistrationsForContainer(container, filtered, msbuildProps.Features); + }); + + EmitContainerOutput(explicitOnlyContainerWithGroups); + EmitContainerOutput(normalContainerWithGroups); + + void EmitContainerOutput(IncrementalValuesProvider groups) + { + var withInfo = groups.Combine(compilationInfoProvider).Combine(msbuildPropertiesProvider); + context.RegisterSourceOutput(withInfo, static (ctx, source) => + { + var ((containerWithGroups, compilationInfo), msbuildProps) = source; + GenerateContainerOutput(in ctx, containerWithGroups, compilationInfo.AssemblyName, msbuildProps, compilationInfo.HasDIPackage); + }); + } + } + + /// + /// Concatenates five streams into a single + /// of via Collect+Combine. + /// Used to merge the four IocRegister* attribute pipelines plus the [IocRegisterDefaults] + /// implementation-type pipeline into one collection before closed-generic resolution. + /// + private static IncrementalValueProvider> CollectAndConcat( + IncrementalValuesProvider a, + IncrementalValuesProvider b, + IncrementalValuesProvider c, + IncrementalValuesProvider d, + IncrementalValuesProvider e) + => a.Collect() + .Combine(b.Collect()) + .Combine(c.Collect()) + .Combine(d.Collect()) + .Combine(e.Collect()) + .Select(static (combined, _) => + { + var ((((p1, p2), p3), p4), p5) = combined; + var builder = ImmutableArray.CreateBuilder(p1.Length + p2.Length + p3.Length + p4.Length + p5.Length); + builder.AddRange(p1); + builder.AddRange(p2); + builder.AddRange(p3); + builder.AddRange(p4); + builder.AddRange(p5); + return builder.MoveToImmutable(); + }); +} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/ContainerWithGroups.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/ContainerWithGroups.cs index d4e6008..cee042d 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/ContainerWithGroups.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/ContainerWithGroups.cs @@ -8,93 +8,3 @@ internal sealed record ContainerWithGroups( ContainerModel Container, ContainerRegistrationGroups Groups); - -/// -/// Helper record to group registrations for container generation. -/// Uses immutable collections for proper incremental generator caching. -/// -/// Registrations grouped by service type and key for efficient lookup. -/// All unique service type names for IsService checks. -/// Singleton registrations with pre-computed names. -/// Scoped registrations with pre-computed names. -/// Transient registrations with pre-computed names. -/// Singleton registrations that should be eagerly resolved. -/// Scoped registrations that should be eagerly resolved. -/// Pre-computed Lazy wrapper entries for container code generation. -/// Pre-computed Func wrapper entries for container code generation. -/// Pre-computed KeyValuePair entries for container code generation. -/// Whether there are any open generic registrations. -/// Whether there are any keyed service registrations. -/// Service types with multiple implementations for IEnumerable resolution. -/// Pre-computed registrations for each collection service type. -/// Singletons in reverse order for disposal (excluding open generics). -/// Scoped services in reverse order for disposal (excluding open generics). -internal sealed record ContainerRegistrationGroups( - ImmutableEquatableDictionary<(string ServiceType, string? Key), ImmutableEquatableArray> ByServiceTypeAndKey, - ImmutableEquatableSet AllServiceTypes, - ImmutableEquatableArray Singletons, - ImmutableEquatableArray Scoped, - ImmutableEquatableArray Transients, - ImmutableEquatableArray EagerSingletons, - ImmutableEquatableArray EagerScoped, - ImmutableEquatableArray LazyEntries, - ImmutableEquatableArray FuncEntries, - ImmutableEquatableArray KvpEntries, - bool HasOpenGenerics, - bool HasKeyedServices, - ImmutableEquatableArray CollectionServiceTypes, - ImmutableEquatableDictionary> CollectionRegistrations, - ImmutableEquatableArray ReversedSingletonsForDisposal, - ImmutableEquatableArray ReversedScopedForDisposal); - -/// -/// A registration with pre-computed field and method names for efficient code generation. -/// -/// The original service registration model. -/// The pre-computed field name for storing the service instance. -/// The pre-computed resolver method name. -/// Whether this registration should be eagerly resolved during container/scope construction. -/// Whether this registration has async initialization members (pre-computed from ). -internal readonly record struct CachedRegistration( - ServiceRegistrationModel Registration, - string FieldName, - string ResolverMethodName, - bool IsEager, - bool IsAsyncInit); - -/// -/// Represents a Lazy resolver entry for container code generation. -/// -/// The fully-qualified inner service type name. -/// The method name of the inner service resolver. -/// The field name for storing the wrapper instance. -internal readonly record struct ContainerLazyEntry( - string InnerServiceTypeName, - string ResolverMethodName, - string FieldName); - -/// -/// Represents a Func resolver entry for container code generation. -/// -/// The fully-qualified inner service type name. -/// The method name of the inner service resolver. -/// The field name for storing the wrapper instance. -internal readonly record struct ContainerFuncEntry( - string InnerServiceTypeName, - string ResolverMethodName, - string FieldName); - -/// -/// Represents a KeyValuePair resolver entry for container code generation. -/// -/// The fully-qualified key type name. -/// The fully-qualified value type name. -/// The key literal expression. -/// The method name of the value service resolver. -/// The method name for this KVP resolver. -internal readonly record struct ContainerKvpEntry( - string KeyTypeName, - string ValueTypeName, - string KeyExpr, - string ResolverMethodName, - string KvpResolverMethodName); diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/IocFeatures.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/IocFeatures.cs index 9198af8..1267ad3 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/IocFeatures.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/IocFeatures.cs @@ -1,4 +1,3 @@ -#nullable enable namespace SourceGen.Ioc.SourceGenerator.Models; diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/MsBuildProperties.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/MsBuildProperties.cs index 07daf5f..567330c 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/MsBuildProperties.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/MsBuildProperties.cs @@ -3,5 +3,4 @@ namespace SourceGen.Ioc.SourceGenerator.Models; internal sealed record MsBuildProperties( string? RootNamespace, string? CustomIocName, - ServiceLifetime? DefaultLifetime, IocFeatures Features); diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TransformExtensions.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TransformExtensions.cs deleted file mode 100644 index 7360457..0000000 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TransformExtensions.cs +++ /dev/null @@ -1,1734 +0,0 @@ -ο»Ώnamespace SourceGen.Ioc.SourceGenerator.Models; - -internal static class TransformExtensions -{ - private static string GetNameWithoutGeneric(string typeName) - { - int angleIndex = typeName.IndexOf('<'); - return angleIndex > 0 ? typeName[..angleIndex] : typeName; - } - - extension(ITypeSymbol typeSymbol) - { - public TypeData GetTypeData( - bool extractConstructorParams = false, - bool extractHierarchy = false, - HashSet? visited = null) - { - if(typeSymbol is INamedTypeSymbol namedTypeSymbol) - return namedTypeSymbol.GetTypeData(extractConstructorParams, extractHierarchy, visited); - - // Handle array types specially - extract element type information - if(typeSymbol is IArrayTypeSymbol arrayTypeSymbol) - return arrayTypeSymbol.GetTypeData(extractConstructorParams, extractHierarchy, visited); - - var name = typeSymbol.FullyQualifiedName; - return TypeData.CreateSimple(name); - } - } - - extension(INamedTypeSymbol typeSymbol) - { - /// - /// Gets the type data for this type symbol. - /// - /// Whether to extract constructor parameters recursively. - /// Whether to extract all interfaces and base classes. - /// Set of visited types to prevent infinite recursion during constructor parameter extraction. - /// Optional semantic model for resolving nameof() expressions. Only used for top-level extraction, not passed to recursive calls. - /// Whether to extract injection members (properties, fields, methods with [IocInject] attributes). Used for decorators. - public TypeData GetTypeData( - bool extractConstructorParams = false, - bool extractHierarchy = false, - HashSet? visited = null, - SemanticModel? semanticModel = null, - bool extractInjectionMembers = false) - { - visited = extractConstructorParams - ? (visited ?? new(SymbolEqualityComparer.Default)) - : null; - - // Build type name - for unbound generics, use actual type parameter names - var typeName = typeSymbol.BuildTypeName(); - - // Extract type parameters with full constraints - ImmutableEquatableArray? typeParameters = null; - if(typeSymbol.IsGenericType && typeSymbol.TypeArguments.Length > 0) - { - typeParameters = typeSymbol.ExtractTypeParameters(extractConstraints: true, depth: 0); - } - - ImmutableEquatableArray? constructorParams = null; - bool hasInjectConstructor = false; - if(extractConstructorParams && visited is not null) - { - // Pass semanticModel only for top-level extraction - // Recursive calls from within ExtractConstructorParametersWithInfo do not receive semanticModel - // to avoid cross-compilation-unit issues and stack overflow - (constructorParams, hasInjectConstructor) = typeSymbol.ExtractConstructorParametersWithInfo(visited, semanticModel); - } - - // Extract injection members if requested (for decorators) - ImmutableEquatableArray? injectionMembers = null; - if(extractInjectionMembers) - { - injectionMembers = typeSymbol.ExtractInjectionMembersForDecorator(semanticModel); - if(injectionMembers.Length == 0) - { - injectionMembers = null; - } - } - - // Extract hierarchy (interfaces and base classes) if requested - ImmutableEquatableArray? allInterfaces = null; - ImmutableEquatableArray? allBaseClasses = null; - if(extractHierarchy) - { - allInterfaces = typeSymbol.GetAllInterfaces(); - allBaseClasses = typeSymbol.GetAllBaseClasses(); - } - - // Error types (e.g., types from other source generators not yet resolved) - // should be treated as simple types, not open generics - if(typeSymbol.TypeKind == TypeKind.Error) - { - return TypeData.CreateSimple( - typeName, - constructorParams, - hasInjectConstructor, - injectionMembers, - allInterfaces, - allBaseClasses); - } - - // Check if this is a wrapper type (collection or non-collection) for DI - var nameWithoutGeneric = GetNameWithoutGeneric(typeName); - var wrapperKind = typeSymbol.GetWrapperKind(nameWithoutGeneric); - - // Downgrade nested Task or Wrapper shapes to WrapperKind.None. - // These shapes are not supported by the spec and fall back to IServiceProvider resolution. - if(wrapperKind == WrapperKind.Task && typeParameters is { Length: > 0 } && typeParameters[0].Type is WrapperTypeData) - { - wrapperKind = WrapperKind.None; - } - else if(wrapperKind is not WrapperKind.None - && typeParameters is not null - && typeParameters.Any(p => p.Type is TaskTypeData)) - { - wrapperKind = WrapperKind.None; - } - - if(wrapperKind is not WrapperKind.None) - { - return TypeData.CreateWrapper( - typeName, - nameWithoutGeneric, - typeSymbol.ContainsGenericParameters, - typeSymbol.Arity, - wrapperKind, - typeSymbol.IsNestedOpenGeneric, - typeParameters, - constructorParams, - hasInjectConstructor, - injectionMembers, - allInterfaces, - allBaseClasses); - } - - if(typeSymbol.ContainsGenericParameters || typeSymbol.Arity > 0 || typeParameters is { Length: > 0 }) - { - return TypeData.CreateGeneric( - typeName, - nameWithoutGeneric, - typeSymbol.ContainsGenericParameters, - typeSymbol.Arity, - typeSymbol.IsNestedOpenGeneric, - typeParameters, - constructorParams, - hasInjectConstructor, - injectionMembers, - allInterfaces, - allBaseClasses); - } - - return TypeData.CreateSimple( - typeName, - constructorParams, - hasInjectConstructor, - injectionMembers, - allInterfaces, - allBaseClasses); - } - - /// - /// Builds the fully qualified type name for this type symbol. - /// For unbound generic types, uses actual type parameter names instead of empty placeholders. - /// - public string BuildTypeName() - { - // For unbound generic types (e.g., typeof(Handler<,>)), we need to get the - // type parameter names from TypeParameters, not from FullyQualifiedName - // FullyQualifiedName returns "global::Ns.Handler<,>" but we need "global::Ns.Handler" - if(typeSymbol.IsUnboundGenericType && typeSymbol.TypeParametersSource.Length > 0) - { - var nameWithoutGeneric = GetNameWithoutGeneric(typeSymbol.FullyQualifiedName); - var typeParamNames = typeSymbol.TypeParametersSource.Select(tp => tp.Name); - return $"{nameWithoutGeneric}<{string.Join(", ", typeParamNames)}>"; - } - - return typeSymbol.FullyQualifiedName; - } - - /// - /// Core implementation for extracting type parameters. - /// - /// Whether to extract constraint types for each type parameter. - /// Current recursion depth to prevent infinite recursion. - /// An immutable array of type parameters with their resolved types. - public ImmutableEquatableArray ExtractTypeParameters(bool extractConstraints, int depth) - { - const int MaxDepth = 10; // Prevent infinite recursion for pathological cases - - var typeParams = typeSymbol.TypeParametersSource; - if(typeParams.Length == 0 || depth >= MaxDepth) - { - return []; - } - - var typeArgs = typeSymbol.TypeArguments; - List parameters = new(typeParams.Length); - - for(int i = 0; i < typeParams.Length; i++) - { - var typeParam = typeParams[i]; - var typeArg = i < typeArgs.Length ? typeArgs[i] : null; - - // Create TypeData for the type argument - var (typeData, allInterfaces) = typeParam.CreateTypeDataForTypeArg(typeArg, depth); - - // Add interfaces if extracted - if(allInterfaces is { Length: > 0 }) - { - typeData = typeData with { AllInterfaces = allInterfaces }; - } - - // Extract constraints only when requested (to avoid recursion in basic scenarios) - ImmutableEquatableArray? constraintTypes = null; - if(extractConstraints) - { - constraintTypes = typeParam.ConstraintTypes - .Select(ct => ct is INamedTypeSymbol namedCt - ? namedCt.CreateBasicTypeData(depth + 1) - : ct.TypeKind == TypeKind.TypeParameter - ? TypeData.CreateTypeParameter(ct.FullyQualifiedName) - : TypeData.CreateGeneric( - ct.FullyQualifiedName, - GetNameWithoutGeneric(ct.FullyQualifiedName), - ct.ContainsGenericParameters, - 0, - false)) - .ToImmutableEquatableArray(); - } - - parameters.Add(new TypeParameter( - typeParam.Name, - typeData, - constraintTypes, - typeParam.HasValueTypeConstraint, - typeParam.HasReferenceTypeConstraint, - typeParam.HasUnmanagedTypeConstraint, - typeParam.HasNotNullConstraint, - typeParam.HasConstructorConstraint)); - } - - return parameters.ToImmutableEquatableArray(); - } - - /// - /// Gets all interfaces implemented by a type. - /// Creates basic TypeData without recursive type parameter extraction to avoid circular dependencies. - /// - public ImmutableEquatableArray GetAllInterfaces() => - typeSymbol.AllInterfaces.Select(CreateBasicTypeData).ToImmutableEquatableArray(); - - /// - /// Gets all base classes of a type, excluding System.Object. - /// Creates basic TypeData without recursive type parameter extraction to avoid circular dependencies. - /// - public ImmutableEquatableArray GetAllBaseClasses() - { - List result = []; - var baseType = typeSymbol.BaseType; - while(baseType != null && baseType.SpecialType != SpecialType.System_Object) - { - result.Add(baseType.CreateBasicTypeData()); - baseType = baseType.BaseType; - } - return result.ToImmutableEquatableArray(); - } - - /// - /// Creates a basic TypeData with type parameters extracted recursively. - /// Does not extract constraint types to avoid circular dependencies. - /// - public TypeData CreateBasicTypeData(int depth = 0) - { - var typeName = typeSymbol.FullyQualifiedName; - var nameWithoutGeneric = GetNameWithoutGeneric(typeName); - - // Extract type parameters without constraints (to avoid recursion) - ImmutableEquatableArray? typeParameters = null; - if(typeSymbol.IsGenericType && typeSymbol.TypeArguments.Length > 0) - { - typeParameters = typeSymbol.ExtractTypeParameters(extractConstraints: false, depth); - } - - // Check if this is a wrapper type (collection or non-collection) - var wrapperKind = typeSymbol.GetWrapperKind(nameWithoutGeneric); - - if(wrapperKind is not WrapperKind.None) - { - return TypeData.CreateWrapper( - typeName, - nameWithoutGeneric, - typeSymbol.ContainsGenericParameters, - typeSymbol.Arity, - wrapperKind, - typeSymbol.IsNestedOpenGeneric, - typeParameters); - } - - if(typeSymbol.ContainsGenericParameters || typeSymbol.Arity > 0 || typeParameters is { Length: > 0 }) - { - return TypeData.CreateGeneric( - typeName, - nameWithoutGeneric, - typeSymbol.ContainsGenericParameters, - typeSymbol.Arity, - typeSymbol.IsNestedOpenGeneric, - typeParameters); - } - - return TypeData.CreateSimple(typeName); - } - - public IMethodSymbol? SpecifiedOrPrimaryOrMostParametersConstructor - { - get - { - IMethodSymbol? injectCtor = null; - IMethodSymbol? primaryCtor = null; - IMethodSymbol? bestCtor = null; - int maxParameters = -1; - foreach(var ctor in typeSymbol.Constructors) - { - if(ctor.IsImplicitlyDeclared) - continue; - - if(ctor.IsStatic) - continue; - - if(ctor.DeclaredAccessibility is not (Accessibility.Public or Accessibility.Internal)) - continue; - - // IocInjectAttribute/InjectAttribute specified constructor - highest priority - if(ctor.GetAttributes().Any(attr => attr.AttributeClass?.IsInject == true)) - { - injectCtor = ctor; - continue; - } - - var syntaxRef = ctor.DeclaringSyntaxReferences.FirstOrDefault(); - // Primary constructor - second priority - if(syntaxRef?.GetSyntax() is TypeDeclarationSyntax) - { - primaryCtor = ctor; - continue; - } - - // Find constructor with most parameters - lowest priority - if(ctor.Parameters.Length > maxParameters) - { - maxParameters = ctor.Parameters.Length; - bestCtor = ctor; - } - } - // Return by priority: [Inject] > primary > most parameters - return injectCtor ?? primaryCtor ?? bestCtor; - } - } - - /// - /// Extracts constructor parameters from a type and indicates whether the constructor was selected by [Inject] attribute. - /// - /// Set of visited types to prevent infinite recursion. - /// Optional semantic model for resolving nameof() expressions in service keys. Only used for top-level extraction, not passed to recursive calls. - /// A tuple containing the constructor parameters and whether the constructor has [Inject] attribute. - public (ImmutableEquatableArray Parameters, bool HasInjectConstructor) ExtractConstructorParametersWithInfo( - HashSet? visited = null, - SemanticModel? semanticModel = null) - { - // Check if we've already visited this type to prevent infinite recursion - if(visited is not null && !visited.Add(typeSymbol)) - { - return ([], false); - } - - // Get the original definition for open generic types to access constructors - var typeToInspect = typeSymbol.IsGenericType && typeSymbol.IsDefinition - ? typeSymbol - : typeSymbol.OriginalDefinition ?? typeSymbol; - - // Get the constructor: [Inject] marked > primary constructor > most parameters - var constructor = typeToInspect.SpecifiedOrPrimaryOrMostParametersConstructor; - if(constructor is null) - { - return ([], false); - } - - // Check if the selected constructor has [IocInject] or [Inject] attribute - bool hasInjectConstructor = constructor.GetAttributes() - .Any(static attr => attr.AttributeClass?.IsInject == true); - - visited ??= new HashSet(SymbolEqualityComparer.Default); - List parameters = []; - foreach(var param in constructor.Parameters) - { - var paramType = param.Type; - - // Get TypeData using the unified method with recursive constructor extraction - // Also extract hierarchy (interfaces) for generic types to enable IEnumerable detection - var paramTypeData = paramType is INamedTypeSymbol namedParamType - ? namedParamType.GetTypeData(extractConstructorParams: true, extractHierarchy: namedParamType.IsGenericType, visited: visited) - : paramType.GetTypeData(); - - // Check if parameter type is nullable (e.g., IDependency?) - var isNullable = param.NullableAnnotation == NullableAnnotation.Annotated; - - // Check if parameter has an explicit default value (for skipping unresolvable parameters) - var hasDefaultValue = param.HasExplicitDefaultValue; - - // Check for [FromKeyedServices], [Inject], or [ServiceKey] attribute - // SemanticModel is used to resolve nameof() expressions for top-level parameters - var (serviceKey, hasInjectAttribute, hasServiceKeyAttribute, hasFromKeyedServicesAttribute) = param.GetServiceKeyAndAttributeInfo(semanticModel); - - // Get the C# code representation of the default value - var defaultValue = hasDefaultValue ? ToDefaultValueCodeString(param.ExplicitDefaultValue) : null; - - parameters.Add(new ParameterData(param.Name, paramTypeData, - IsNullable: isNullable, - HasDefaultValue: hasDefaultValue, - DefaultValue: defaultValue, - ServiceKey: serviceKey, - HasInjectAttribute: hasInjectAttribute, - HasServiceKeyAttribute: hasServiceKeyAttribute, - HasFromKeyedServicesAttribute: hasFromKeyedServicesAttribute)); - } - - return (parameters.ToImmutableEquatableArray(), hasInjectConstructor); - } - - /// - /// Gets the for the given type symbol. - /// Checks collection types first, then non-collection wrapper types. - /// - /// The type name without generic parameters. - /// The for this type. - public WrapperKind GetWrapperKind(string nameWithoutGeneric) - { - if(typeSymbol.TypeKind == TypeKind.Array) - return WrapperKind.Array; - - if(IsReadOnlyCollectionType(nameWithoutGeneric)) - return WrapperKind.ReadOnlyCollection; - - if(IsReadOnlyListType(nameWithoutGeneric)) - return WrapperKind.ReadOnlyList; - - if(IsCollectionType(nameWithoutGeneric)) - return WrapperKind.Collection; - - if(IsListType(nameWithoutGeneric)) - return WrapperKind.List; - - if(IsEnumerableType(nameWithoutGeneric)) - return WrapperKind.Enumerable; - - return GetNonCollectionWrapperKind(nameWithoutGeneric, typeSymbol.Arity); - } - - /// - /// Determines the for a given type name (without generic part). - /// Returns if the type is not a recognized non-collection wrapper type. - /// Collection types (IEnumerable, IReadOnlyCollection, etc.) are detected separately - /// in via GetWrapperKind. - /// - /// The type name without generic parameters. - /// The number of type parameters. Used to distinguish Task<T> (arity 1) from non-generic Task (arity 0). - public static WrapperKind GetNonCollectionWrapperKind(string nameWithoutGeneric, int arity) => nameWithoutGeneric switch - { - "global::System.Lazy" or "System.Lazy" or "Lazy" => WrapperKind.Lazy, - "global::System.Func" or "System.Func" or "Func" => WrapperKind.Func, - "global::System.Collections.Generic.IDictionary" or "System.Collections.Generic.IDictionary" or "IDictionary" - or "global::System.Collections.Generic.IReadOnlyDictionary" or "System.Collections.Generic.IReadOnlyDictionary" or "IReadOnlyDictionary" - or "global::System.Collections.Generic.Dictionary" or "System.Collections.Generic.Dictionary" or "Dictionary" - => WrapperKind.Dictionary, - "global::System.Collections.Generic.KeyValuePair" or "System.Collections.Generic.KeyValuePair" or "KeyValuePair" => WrapperKind.KeyValuePair, - ("global::System.Threading.Tasks.Task" or "System.Threading.Tasks.Task" or "Task") when arity == 1 => WrapperKind.Task, - _ => WrapperKind.None - }; - - public bool IsInject => - typeSymbol.Name is "IocInjectAttribute" or "InjectAttribute"; - - /// - /// Enumerates members (properties, fields, methods) marked with IocInjectAttribute/InjectAttribute. - /// This is a shared method used by both Analyzer (ServiceInfo) and Generator (RegistrationData). - /// - /// - /// The method filters members based on: - /// - Non-static members only - /// - Properties with a setter - /// - Non-readonly fields - /// - Ordinary methods that return (sync) or non-generic - /// (async, when AsyncMethodInject - /// feature is enabled), and are not generic - /// - /// - /// An enumerable of tuples containing the member symbol and its inject attribute. - /// Analyzer can use ISymbol directly; Generator can convert to InjectionMemberData. - /// - public IEnumerable<(ISymbol Member, AttributeData InjectAttribute)> GetInjectedMembers() - { - // For unbound generic types (e.g., LoggingDecorator<,>), we need to use OriginalDefinition - // to get the actual member declarations with their attributes - var typeToInspect = typeSymbol.IsUnboundGenericType ? typeSymbol.OriginalDefinition : typeSymbol; - - foreach(var member in typeToInspect.GetMembers()) - { - // Skip static members - if(member.IsStatic) - continue; - - // Check if the member has IocInjectAttribute/InjectAttribute (by name only) - var injectAttribute = member.GetAttributes() - .FirstOrDefault(static attr => attr.AttributeClass?.IsInject == true); - - if(injectAttribute is null) - continue; - - // Validate member is injectable based on type - var isInjectable = member switch - { - IPropertySymbol property => property.SetMethod is not null, - IFieldSymbol field => !field.IsReadOnly, - IMethodSymbol method => method.MethodKind == MethodKind.Ordinary - && (method.ReturnsVoid || RoslynExtensions.IsNonGenericTaskReturnType(method)) - && !method.IsGenericMethod, - _ => false - }; - - if(isInjectable) - { - yield return (member, injectAttribute); - } - } - } - } - - extension(IParameterSymbol param) - { - /// - /// Gets the service key, injection attribute info, and [ServiceKey]/[FromKeyedServices] attribute from a parameter. - /// [FromKeyedServices] takes precedence over [Inject] for service key resolution. - /// HasInjectAttribute is only true for [Inject] attribute (not [FromKeyedServices], which MS.DI handles automatically). - /// HasServiceKeyAttribute indicates the parameter is marked with [ServiceKey] from Microsoft.Extensions.DependencyInjection. - /// HasFromKeyedServicesAttribute indicates the parameter is marked with [FromKeyedServices] from Microsoft.Extensions.DependencyInjection. - /// - /// A tuple containing the service key (if any), whether the parameter has [Inject] attribute, [ServiceKey] attribute, and [FromKeyedServices] attribute. - public (string? ServiceKey, bool HasInjectAttribute, bool HasServiceKeyAttribute, bool HasFromKeyedServicesAttribute) GetServiceKeyAndAttributeInfo(SemanticModel? semanticModel = null) - { - string? serviceKey = null; - bool hasInjectAttribute = false; - bool hasServiceKeyAttribute = false; - bool hasFromKeyedServicesAttribute = false; - - foreach(var attribute in param.GetAttributes()) - { - var attrClass = attribute.AttributeClass; - if(attrClass is null) - continue; - - // Check for Microsoft.Extensions.DependencyInjection.ServiceKeyAttribute - if(attrClass.Name == "ServiceKeyAttribute" - && attrClass.ContainingNamespace?.ToDisplayString() == "Microsoft.Extensions.DependencyInjection") - { - hasServiceKeyAttribute = true; - continue; - } - - // Check for Microsoft.Extensions.DependencyInjection.FromKeyedServicesAttribute (higher priority for key) - // Note: [FromKeyedServices] is handled by MS.DI automatically, so we don't set hasInjectAttribute - if(attrClass.Name == "FromKeyedServicesAttribute" - && attrClass.ContainingNamespace?.ToDisplayString() == "Microsoft.Extensions.DependencyInjection") - { - hasFromKeyedServicesAttribute = true; - // The key is the first constructor argument - if(attribute.ConstructorArguments.Length > 0) - { - var keyArg = attribute.ConstructorArguments[0]; - if(!keyArg.IsNull && keyArg.Value is not null) - { - serviceKey = keyArg.GetPrimitiveConstantString(); - } - } - // [FromKeyedServices] found, but continue to check for [Inject] as well - continue; - } - - // Check for IocInjectAttribute/InjectAttribute (by name only, to support third-party attributes) - if(attrClass.IsInject) - { - hasInjectAttribute = true; - // Only use [Inject] key if no [FromKeyedServices] key was found - if(serviceKey is null) - { - var (key, _, _) = attribute.GetKeyInfo(semanticModel); - serviceKey = key; - } - } - } - return (serviceKey, hasInjectAttribute, hasServiceKeyAttribute, hasFromKeyedServicesAttribute); - } - } - - extension(ITypeParameterSymbol typeParam) - { - /// - /// Creates TypeData for a type argument, with optional interface extraction. - /// - public (TypeData TypeData, ImmutableEquatableArray? AllInterfaces) CreateTypeDataForTypeArg( - ITypeSymbol? typeArg, - int depth) - { - if(typeArg is INamedTypeSymbol namedArg && typeArg.TypeKind != TypeKind.TypeParameter) - { - // For concrete types, recursively extract type parameters and interfaces - var typeData = namedArg.CreateBasicTypeData(depth + 1); - var allInterfaces = namedArg.AllInterfaces.Length > 0 - ? namedArg.AllInterfaces.Select(CreateInterfaceTypeData).ToImmutableEquatableArray() - : null; - return (typeData, allInterfaces); - } - - if(typeArg is not null) - { - var argName = typeArg.FullyQualifiedName; - TypeData typeData = typeArg.TypeKind == TypeKind.TypeParameter - ? TypeData.CreateTypeParameter(argName) - : TypeData.CreateGeneric( - argName, - GetNameWithoutGeneric(argName), - typeArg.ContainsGenericParameters, - 0); - return (typeData, null); - } - - // No type argument available, this is a type parameter placeholder - return (TypeData.CreateTypeParameter(typeParam.Name), null); - - // Creates a simple TypeData for an interface type. - static TypeData CreateInterfaceTypeData(INamedTypeSymbol iface) => - iface.IsGenericType || iface.Arity > 0 - ? TypeData.CreateGeneric( - iface.FullyQualifiedName, - GetNameWithoutGeneric(iface.FullyQualifiedName), - iface.ContainsGenericParameters, - iface.Arity, - false) - : TypeData.CreateSimple(iface.FullyQualifiedName); - } - } - - extension(IArrayTypeSymbol arrayTypeSymbol) - { - public TypeData GetTypeData( - bool extractConstructorParams = false, - bool extractHierarchy = false, - HashSet? visited = null) - { - var elementType = arrayTypeSymbol.ElementType; - var typeName = arrayTypeSymbol.FullyQualifiedName; - - // For arrays, create TypeData with element type as a pseudo-TypeParameter - // This allows TryGetArrayElementType to extract the element type - ImmutableEquatableArray typeParameters; - if(elementType is INamedTypeSymbol namedElementType) - { - var elementTypeData = namedElementType.GetTypeData(extractConstructorParams, extractHierarchy, visited); - typeParameters = [new TypeParameter("T", elementTypeData)]; - } - else - { - var elementTypeName = elementType.FullyQualifiedName; - TypeData elementTypeData = elementType.TypeKind == TypeKind.TypeParameter - ? TypeData.CreateTypeParameter(elementTypeName) - : TypeData.CreateGeneric( - elementTypeName, - GetNameWithoutGeneric(elementTypeName), - elementType.ContainsGenericParameters, - 0); - typeParameters = [new TypeParameter("T", elementTypeData)]; - } - - return TypeData.CreateWrapper( - typeName, - typeName, // For arrays, use full name as NameWithoutGeneric - elementType.ContainsGenericParameters, - GenericArity: 1, // Arrays have one "type parameter" (the element type) - WrapperKind.Array, // Arrays are collections - IsNestedOpenGeneric: false, - TypeParameters: typeParameters); - } - } - - extension(AttributeData attribute) - { - /// - /// Gets an array of type symbols from a named argument. - /// - /// The name of the named argument. - /// Whether to extract constructor parameters. - /// Whether to extract injection members (for decorators). - public ImmutableEquatableArray GetTypeArrayArgument( - string name, - bool extractConstructorParams = false, - bool extractInjectionMembers = false) - { - foreach(var namedArg in attribute.NamedArguments) - { - if(namedArg.Key.Equals(name, StringComparison.Ordinal) && !namedArg.Value.IsNull && namedArg.Value.Kind == TypedConstantKind.Array) - { - List result = []; - foreach(var value in namedArg.Value.Values) - { - if(value.Value is INamedTypeSymbol namedTypeSymbol) - { - result.Add(namedTypeSymbol.GetTypeData( - extractConstructorParams, - extractHierarchy: false, - visited: null, - semanticModel: null, - extractInjectionMembers)); - } - else if(value.Value is ITypeSymbol typeSymbol) - { - result.Add(typeSymbol.GetTypeData(extractConstructorParams)); - } - } - return result.ToImmutableEquatableArray(); - } - } - - return []; - } - - /// - /// Gets an array of type symbols from an attribute constructor argument of type params Type[]. - /// This is the constructor-argument counterpart to , and is used when - /// service types are supplied positionally to the attribute constructor instead of via a named Type[] argument. - /// - /// - /// This method scans the attribute's constructor arguments for an array (or params) argument that contains - /// type values (for example, params Type[] serviceTypes) and converts those instances - /// to . It skips non-type arguments, such as ServiceLifetime enum values. - /// - public ImmutableEquatableArray GetTypeArrayFromConstructorArgument(bool extractConstructorParams = false) - { - foreach(var ctorArg in attribute.ConstructorArguments) - { - // Look for an array argument containing type values - if(ctorArg.Kind == TypedConstantKind.Array && !ctorArg.IsNull) - { - List result = []; - foreach(var value in ctorArg.Values) - { - if(value.Value is ITypeSymbol typeSymbol) - { - result.Add(typeSymbol.GetTypeData(extractConstructorParams)); - } - } - - // Only return if we found type values - if(result.Count > 0) - return result.ToImmutableEquatableArray(); - } - } - - return []; - } - - public (bool HasArg, ServiceLifetime Lifetime) TryGetLifetime() - { - // First, check if lifetime is passed as a constructor argument (for generic attributes like IoCRegisterAttribute(ServiceLifetime.Scoped)) - foreach(var ctorArg in attribute.ConstructorArguments) - { - if(ctorArg.Type?.Name == nameof(ServiceLifetime) && ctorArg.Value is int lifetimeValue) - { - return (true, (ServiceLifetime)lifetimeValue); - } - } - - // Fall back to named argument - var (hasArg, val) = attribute.TryGetNamedArgument("Lifetime", 2); // Default is ServiceLifetime.Transient - return (hasArg, (ServiceLifetime)val); - } - - public (bool HasArg, bool Value) TryGetRegisterAllInterfaces() => - attribute.TryGetNamedArgument("RegisterAllInterfaces", false); - - public (bool HasArg, bool Value) TryGetRegisterAllBaseClasses() => - attribute.TryGetNamedArgument("RegisterAllBaseClasses", false); - - /// - /// Gets the service types from the attribute. - /// This method checks both named arguments and constructor arguments for service types. - /// - /// - /// The method first checks for a named argument "ServiceTypes" (e.g., ServiceTypes = [typeof(IService)]). - /// If not found, it checks constructor arguments for an array of types (e.g., params Type[] serviceTypes). - /// - public ImmutableEquatableArray GetServiceTypes() - { - // First, try to get from named argument - var namedResult = attribute.GetTypeArrayArgument("ServiceTypes"); - if(namedResult.Length > 0) - return namedResult; - - // Fall back to constructor argument (params Type[] serviceTypes) - return attribute.GetTypeArrayFromConstructorArgument(); - } - - /// - /// Gets the service types from generic attribute type parameters (e.g., IoCRegisterAttribute<T1, T2>). - /// - public ImmutableEquatableArray GetServiceTypesFromGenericAttribute() - { - var attrClass = attribute.AttributeClass; - if(attrClass?.IsGenericType != true || attrClass.TypeArguments.Length == 0) - return []; - - List result = []; - foreach(var typeArg in attrClass.TypeArguments) - { - if(typeArg is INamedTypeSymbol namedType) - { - result.Add(namedType.GetTypeData()); - } - } - return result.ToImmutableEquatableArray(); - } - - public ImmutableEquatableArray GetDecorators() => - attribute.GetTypeArrayArgument("Decorators", extractConstructorParams: true, extractInjectionMembers: true); - - /// - /// Gets the ImplementationTypes array from the attribute. - /// Extracts implementation types with constructor parameters and hierarchy information, - /// using the same parsing logic as IocRegisterAttribute. - /// - public ImmutableEquatableArray GetImplementationTypes() => - attribute.GetTypeArrayArgumentWithHierarchy("ImplementationTypes"); - - /// - /// Gets the ImplementationTypes array as INamedTypeSymbol from the attribute. - /// Used when full symbol access is needed for injection member extraction. - /// - public ImmutableEquatableArray GetImplementationTypeSymbols() => - attribute.GetTypeSymbolsFromNamedArgument("ImplementationTypes"); - - /// - /// Gets an array of type symbols from a named argument. - /// Used when full symbol access is needed for further analysis. - /// - public ImmutableEquatableArray GetTypeSymbolsFromNamedArgument(string name) - { - foreach(var namedArg in attribute.NamedArguments) - { - if(namedArg.Key.Equals(name, StringComparison.Ordinal) && !namedArg.Value.IsNull && namedArg.Value.Kind == TypedConstantKind.Array) - { - List result = []; - foreach(var value in namedArg.Value.Values) - { - if(value.Value is INamedTypeSymbol namedTypeSymbol) - { - result.Add(namedTypeSymbol); - } - } - return result.ToImmutableEquatableArray(); - } - } - - return []; - } - - /// - /// Gets an array of type symbols from a named argument with full hierarchy extraction. - /// Used for ImplementationTypes where we need constructor params and all interfaces/base classes. - /// - public ImmutableEquatableArray GetTypeArrayArgumentWithHierarchy(string name) - { - foreach(var namedArg in attribute.NamedArguments) - { - if(namedArg.Key.Equals(name, StringComparison.Ordinal) && !namedArg.Value.IsNull && namedArg.Value.Kind == TypedConstantKind.Array) - { - List result = []; - foreach(var value in namedArg.Value.Values) - { - if(value.Value is INamedTypeSymbol namedTypeSymbol) - { - // Extract with constructor params and hierarchy, same as IocRegisterAttribute - result.Add(namedTypeSymbol.GetTypeData(extractConstructorParams: true, extractHierarchy: true)); - } - } - return result.ToImmutableEquatableArray(); - } - } - - return []; - } - - /// - /// Gets the Tags array from the attribute. - /// - public ImmutableEquatableArray GetTags() - { - foreach(var namedArg in attribute.NamedArguments) - { - if(namedArg.Key.Equals("Tags", StringComparison.Ordinal) && !namedArg.Value.IsNull && namedArg.Value.Kind == TypedConstantKind.Array) - { - List result = []; - foreach(var value in namedArg.Value.Values) - { - if(value.Value is string tag) - { - result.Add(tag); - } - } - return result.ToImmutableEquatableArray(); - } - } - - return []; - } - - /// - /// Checks if the attribute has Factory or Instance specified. - /// - /// A tuple indicating whether Factory and/or Instance are specified. - public (bool HasFactory, bool HasInstance) HasFactoryOrInstance() - { - bool hasFactory = false; - bool hasInstance = false; - - foreach(var namedArg in attribute.NamedArguments) - { - if(namedArg.Key == "Factory" && !namedArg.Value.IsNull) - { - hasFactory = true; - } - else if(namedArg.Key == "Instance" && !namedArg.Value.IsNull) - { - hasInstance = true; - } - - // Early exit if both found - if(hasFactory && hasInstance) - break; - } - - return (hasFactory, hasInstance); - } - - /// - /// Gets the target type from an IoCRegisterForAttribute. - /// For non-generic variant, extracts from constructor argument. - /// For generic variant (IoCRegisterForAttribute<T>), extracts from type parameter. - /// - /// The target type symbol, or null if not found. - public INamedTypeSymbol? GetTargetTypeFromRegisterForAttribute() - { - var attributeClass = attribute.AttributeClass; - if(attributeClass is null) - return null; - - // For generic IoCRegisterForAttribute, get T from type arguments - if(attributeClass.IsGenericType && attributeClass.TypeArguments.Length > 0) - { - return attributeClass.TypeArguments[0] as INamedTypeSymbol; - } - - // For non-generic IoCRegisterForAttribute, get from constructor argument - if(attribute.ConstructorArguments.Length > 0 && - attribute.ConstructorArguments[0].Value is INamedTypeSymbol targetType) - { - return targetType; - } - - return null; - } - - /// - /// Gets all key-related information from the attribute in a single pass: - /// key string, key type, and key value type symbol (with optional nameof() resolution). - /// - /// Optional semantic model to resolve nameof() expression types and full access paths. - /// - /// A tuple containing: - /// - Key: The key string, or null if no key is specified. - /// - KeyType: The key type (0 = Value, 1 = Csharp). - /// - KeyValueTypeSymbol: The type symbol of the key value, or null when the type cannot be determined. - /// - public (string? Key, int KeyType, ITypeSymbol? KeyValueTypeSymbol) GetKeyInfo(SemanticModel? semanticModel = null) - { - var keyType = attribute.GetNamedArgument("KeyType", 0); - var isCsharpKeyType = keyType == 1; - - // First, check if key is passed as a constructor argument (e.g., InjectAttribute(object key)) - if(attribute.ConstructorArguments.Length > 0) - { - var ctorArg = attribute.ConstructorArguments[0]; - // Skip if the first argument is a type, lifetime enum, or array (e.g., IoCRegisterDefaultsAttribute) - if(ctorArg.Type?.Name != nameof(ServiceLifetime) - && ctorArg.Kind != TypedConstantKind.Type - && ctorArg.Kind != TypedConstantKind.Array - && !ctorArg.IsNull) - { - if(isCsharpKeyType) - { - // Try to get original syntax for nameof() expressions with full access path resolution - var key = attribute.TryGetNameofFromConstructorArg(0, semanticModel) - ?? ctorArg.Value?.ToString(); - var keyValueType = TryResolveNameofTypeFromConstructorArg(attribute, 0, semanticModel); - return (key, keyType, keyValueType); - } - - // Value key: treat the primitive constant as CSharp code - return (ctorArg.GetPrimitiveConstantString(), 1, ctorArg.Type); - } - } - - // Fall back to named argument - foreach(var namedArg in attribute.NamedArguments) - { - if(namedArg.Key != "Key") - continue; - - if(namedArg.Value.IsNull) - return (null, keyType, null); - - if(isCsharpKeyType) - { - // Try to get original syntax for nameof() expressions with full access path resolution - var key = attribute.TryGetNameof("Key", semanticModel) - ?? namedArg.Value.Value?.ToString(); - var keyValueType = TryResolveNameofTypeFromNamedArg(attribute, "Key", semanticModel); - return (key, keyType, keyValueType); - } - - // Value key: treat the primitive constant as CSharp code - return (namedArg.Value.GetPrimitiveConstantString(), 1, namedArg.Value.Type); - } - - return (null, keyType, null); - } - - /// - /// Tries to resolve the type of a nameof() expression in a constructor argument. - /// Returns null if the argument is not a nameof() expression or cannot be resolved. - /// - private static ITypeSymbol? TryResolveNameofTypeFromConstructorArg(AttributeData attr, int argumentIndex, SemanticModel? semanticModel) - { - if(semanticModel is null) - return null; - - var syntaxReference = attr.ApplicationSyntaxReference; - if(syntaxReference?.GetSyntax() is not AttributeSyntax attributeSyntax) - return null; - - var argumentList = attributeSyntax.ArgumentList; - if(argumentList is null || argumentList.Arguments.Count <= argumentIndex) - return null; - - var argument = argumentList.Arguments[argumentIndex]; - if(argument.NameEquals is not null) - return null; - - return ResolveNameofExpressionType(argument.Expression, semanticModel); - } - - /// - /// Tries to resolve the type of a nameof() expression in a named argument. - /// Returns null if the argument is not a nameof() expression or cannot be resolved. - /// - private static ITypeSymbol? TryResolveNameofTypeFromNamedArg(AttributeData attr, string argumentName, SemanticModel? semanticModel) - { - if(semanticModel is null) - return null; - - var syntaxReference = attr.ApplicationSyntaxReference; - if(syntaxReference?.GetSyntax() is not AttributeSyntax attributeSyntax) - return null; - - var argumentList = attributeSyntax.ArgumentList; - if(argumentList is null) - return null; - - foreach(var argument in argumentList.Arguments) - { - if(argument.NameEquals?.Name.Identifier.Text == argumentName) - { - return ResolveNameofExpressionType(argument.Expression, semanticModel); - } - } - - return null; - } - - /// - /// If the expression is a nameof() invocation, resolves the referenced symbol's type. - /// Returns null for non-nameof expressions. - /// - private static ITypeSymbol? ResolveNameofExpressionType(ExpressionSyntax expression, SemanticModel semanticModel) - { - if(expression is not InvocationExpressionSyntax invocation - || invocation.Expression is not IdentifierNameSyntax identifierName - || identifierName.Identifier.Text != "nameof" - || invocation.ArgumentList.Arguments.Count != 1) - { - return null; - } - - var nameofArgument = invocation.ArgumentList.Arguments[0].Expression; - var symbolInfo = semanticModel.GetSymbolInfo(nameofArgument); - var symbol = symbolInfo.Symbol ?? symbolInfo.CandidateSymbols.FirstOrDefault(); - - return symbol switch - { - IFieldSymbol field => field.Type, - IPropertySymbol property => property.Type, - IMethodSymbol method => method.ReturnType, - ILocalSymbol local => local.Type, - IParameterSymbol param => param.Type, - _ => null, - }; - } - - /// - /// Extracts from the registration attribute's - /// GenericFactoryTypeMapping named property. - /// Used as a fallback when [IocGenericFactory] is not present on the factory method. - /// - /// The generic factory type mapping, or null if not specified or invalid. - public GenericFactoryTypeMapping? ExtractGenericFactoryMappingFromAttributeProperty() - { - foreach(var namedArg in attribute.NamedArguments) - { - if(namedArg.Key != "GenericFactoryTypeMapping") - continue; - - if(namedArg.Value.Kind != TypedConstantKind.Array || namedArg.Value.IsNull) - return null; - - var typeArray = namedArg.Value.Values; - if(typeArray.Length < 2) - return null; - - if(typeArray[0].Value is not INamedTypeSymbol serviceTypeTemplate) - return null; - - var serviceTypeTemplateData = serviceTypeTemplate.GetTypeData(); - - var placeholderMap = new Dictionary(StringComparer.Ordinal); - for(int i = 1; i < typeArray.Length; i++) - { - if(typeArray[i].Value is ITypeSymbol placeholderType) - { - var placeholderTypeName = placeholderType.FullyQualifiedName; - if(placeholderMap.ContainsKey(placeholderTypeName)) - return null; // Duplicate placeholder - placeholderMap[placeholderTypeName] = i - 1; - } - } - - if(placeholderMap.Count != typeArray.Length - 1) - return null; - - return new GenericFactoryTypeMapping( - serviceTypeTemplateData, - placeholderMap.ToImmutableEquatableDictionary()); - } - - return null; - } - - /// - /// Gets the Factory method data from the attribute, including parameter and return type information. - /// When the resolved factory method is generic but has no [IocGenericFactory] attribute, - /// falls back to the GenericFactoryTypeMapping property on the registration attribute. - /// - /// Semantic model to resolve method symbols. - /// The factory method data, or null if not specified. - public FactoryMethodData? GetFactoryMethodData(SemanticModel semanticModel) - { - var syntaxReference = attribute.ApplicationSyntaxReference; - if(syntaxReference?.GetSyntax() is not AttributeSyntax attributeSyntax) - return null; - - var argumentList = attributeSyntax.ArgumentList; - if(argumentList is null) - return null; - - foreach(var argument in argumentList.Arguments) - { - if(argument.NameEquals?.Name.Identifier.Text != "Factory") - continue; - - // Check if the expression is a nameof() invocation - if(argument.Expression is InvocationExpressionSyntax invocation && - invocation.Expression is IdentifierNameSyntax identifierName && - identifierName.Identifier.Text == "nameof" && - invocation.ArgumentList.Arguments.Count == 1) - { - var nameofArgument = invocation.ArgumentList.Arguments[0].Expression; - var methodSymbol = ResolveMethodSymbol(nameofArgument, semanticModel); - - if(methodSymbol is not null) - { - var factoryData = CreateFactoryMethodData(methodSymbol); - - // Fallback: if method is generic but has no [IocGenericFactory], check attribute's GenericFactoryTypeMapping - if(factoryData.GenericTypeMapping is null && methodSymbol.TypeParameters.Length > 0) - { - var mappingFromAttr = attribute.ExtractGenericFactoryMappingFromAttributeProperty(); - if(mappingFromAttr is not null) - factoryData = factoryData with { GenericTypeMapping = mappingFromAttr }; - } - - return factoryData; - } - - // Fallback: get path from nameof expression - var nameofPath = ResolveNameofExpression(nameofArgument, semanticModel) - ?? nameofArgument.ToFullString().Trim(); - return new FactoryMethodData(nameofPath, HasServiceProvider: true, HasKey: false, ReturnTypeName: null, AdditionalParameters: []); - } - - // String literal - cannot determine parameters, assume full signature - if(argument.Expression is LiteralExpressionSyntax literal && - literal.Token.Value is string literalPath) - { - return new FactoryMethodData(literalPath, HasServiceProvider: true, HasKey: false, ReturnTypeName: null, AdditionalParameters: []); - } - } - - return null; - } - - /// - /// Gets the Instance path from the attribute. - /// - /// Optional semantic model to resolve full access paths for nameof() expressions. - /// The static instance path (e.g., "MyService.Default"), or null if not specified. - public string? GetInstance(SemanticModel? semanticModel = null) - { - foreach(var namedArg in attribute.NamedArguments) - { - if(namedArg.Key == "Instance") - { - if(namedArg.Value.IsNull) - return null; - - // Try to get original syntax for nameof() expressions with full access path resolution - return attribute.TryGetNameof("Instance", semanticModel) - ?? namedArg.Value.Value?.ToString(); - } - } - - return null; - } - - /// - /// Determines if the attribute will cause registration of interfaces or base classes. - /// For open generic types, nested open generics are only a problem when registering interfaces/base classes. - /// - public bool WillRegisterInterfacesOrBaseClasses() - { - // Check if ServiceTypes is specified - var serviceTypes = attribute.GetServiceTypes(); - if(serviceTypes.Length > 0) - return true; - - // Check if RegisterAllInterfaces is true - var (hasRegisterAllInterfaces, registerAllInterfaces) = attribute.TryGetRegisterAllInterfaces(); - if(hasRegisterAllInterfaces && registerAllInterfaces) - return true; - - // Check if RegisterAllBaseClasses is true - var (hasRegisterAllBaseClasses, registerAllBaseClasses) = attribute.TryGetRegisterAllBaseClasses(); - if(hasRegisterAllBaseClasses && registerAllBaseClasses) - return true; - - // Only registering self, no interfaces/base classes - return false; - } - - /// - /// Extracts default settings from an IoCRegisterDefaultSettingsAttribute. - /// - /// Optional semantic model for resolving Factory method data. - /// The default settings model, or null if the attribute data is invalid. - public DefaultSettingsModel? ExtractDefaultSettings(SemanticModel? semanticModel = null) - { - if(attribute.ConstructorArguments.Length < 2) - return null; - if(attribute.ConstructorArguments[0].Value is not INamedTypeSymbol targetServiceType) - return null; - if(attribute.ConstructorArguments[1].Value is not int lifetime) - return null; - - var (_, registerAllInterfaces) = attribute.TryGetRegisterAllInterfaces(); - var (_, registerAllBaseClasses) = attribute.TryGetRegisterAllBaseClasses(); - var serviceTypes = attribute.GetServiceTypes(); - var typeData = targetServiceType.GetTypeData(); - var decorators = attribute.GetDecorators(); - var tags = attribute.GetTags(); - - // Get factory method data if semantic model is provided - FactoryMethodData? factory = null; - if(semanticModel is not null) - { - factory = attribute.GetFactoryMethodData(semanticModel); - } - - // Get implementation types with constructor params and hierarchy (same as IocRegisterAttribute) - var implementationTypes = attribute.GetImplementationTypes(); - - return new DefaultSettingsModel( - typeData, - (ServiceLifetime)lifetime, - registerAllInterfaces, - registerAllBaseClasses, - serviceTypes, - decorators, - tags, - factory, - implementationTypes); - } - - /// - /// Extracts default settings from a generic IoCRegisterDefaultsAttribute (e.g., IoCRegisterDefaultsAttribute<T>). - /// The target service type is specified via type parameter instead of constructor argument. - /// - /// Optional semantic model for resolving Factory method data. - /// The default settings model, or null if the attribute data is invalid. - public DefaultSettingsModel? ExtractDefaultSettingsFromGenericAttribute(SemanticModel? semanticModel = null) - { - var attrClass = attribute.AttributeClass; - if(attrClass?.IsGenericType != true || attrClass.TypeArguments.Length == 0) - return null; - - if(attrClass.TypeArguments[0] is not INamedTypeSymbol targetServiceType) - return null; - - // Lifetime is the first constructor argument for the generic version - if(attribute.ConstructorArguments.Length < 1) - return null; - if(attribute.ConstructorArguments[0].Value is not int lifetime) - return null; - - var (_, registerAllInterfaces) = attribute.TryGetRegisterAllInterfaces(); - var (_, registerAllBaseClasses) = attribute.TryGetRegisterAllBaseClasses(); - var serviceTypes = attribute.GetServiceTypes(); - var typeData = targetServiceType.GetTypeData(); - var decorators = attribute.GetDecorators(); - var tags = attribute.GetTags(); - - // Get factory method data if semantic model is provided - FactoryMethodData? factory = null; - if(semanticModel is not null) - { - factory = attribute.GetFactoryMethodData(semanticModel); - } - - // Get implementation types with constructor params and hierarchy (same as IocRegisterAttribute) - var implementationTypes = attribute.GetImplementationTypes(); - - return new DefaultSettingsModel( - typeData, - (ServiceLifetime)lifetime, - registerAllInterfaces, - registerAllBaseClasses, - serviceTypes, - decorators, - tags, - factory, - implementationTypes); - } - } - - extension(IMethodSymbol methodSymbol) - { - /// - /// Creates FactoryMethodData from a method symbol. - /// Analyzes factory method parameters: - /// - IServiceProvider: Will be passed the service provider directly - /// - [ServiceKey] attribute: Will be passed the registration key value - /// - Other parameters: Will be resolved from the service provider using the same logic as [IocInject] methods - /// Also extracts [IocGenericFactory] attribute if present for generic factory method support. - /// - public FactoryMethodData CreateFactoryMethodData() - { - var path = methodSymbol.FullAccessPath; - bool hasServiceProvider = false; - bool hasKey = false; - List? additionalParameters = null; - - foreach(var param in methodSymbol.Parameters) - { - var paramTypeName = param.Type.FullyQualifiedName; - - // Check for IServiceProvider - if(paramTypeName is "global::System.IServiceProvider" or "System.IServiceProvider") - { - hasServiceProvider = true; - continue; - } - - // Check for [ServiceKey] attribute - bool hasServiceKeyAttribute = false; - foreach(var attribute in param.GetAttributes()) - { - var attrClass = attribute.AttributeClass; - if(attrClass is null) - continue; - - if(attrClass.Name == "ServiceKeyAttribute" - && attrClass.ContainingNamespace?.ToDisplayString() == "Microsoft.Extensions.DependencyInjection") - { - hasServiceKeyAttribute = true; - hasKey = true; - break; - } - } - - // Skip [ServiceKey] parameters from additional parameters - if(hasServiceKeyAttribute) - continue; - - // Collect additional parameter info using the same logic as [IocInject] methods - var (serviceKey, hasInjectAttribute, _, hasFromKeyedServicesAttribute) = param.GetServiceKeyAndAttributeInfo(); - var parameterData = new ParameterData( - param.Name, - param.Type.GetTypeData(), - IsNullable: param.NullableAnnotation == NullableAnnotation.Annotated, - HasDefaultValue: param.HasExplicitDefaultValue, - DefaultValue: param.HasExplicitDefaultValue ? ToDefaultValueCodeString(param.ExplicitDefaultValue) : null, - ServiceKey: serviceKey, - HasInjectAttribute: hasInjectAttribute, - HasServiceKeyAttribute: false, // Already handled above - HasFromKeyedServicesAttribute: hasFromKeyedServicesAttribute); - - additionalParameters ??= []; - additionalParameters.Add(parameterData); - } - - // Always store the return type for runtime comparison - var returnTypeName = methodSymbol.ReturnType.FullyQualifiedName; - - // Extract [IocGenericFactory] attribute if present - var genericTypeMapping = methodSymbol.ExtractGenericFactoryMapping(); - var typeParameterCount = methodSymbol.TypeParameters.Length; - - return new FactoryMethodData( - path, - hasServiceProvider, - hasKey, - returnTypeName, - additionalParameters?.ToImmutableEquatableArray() ?? [], - genericTypeMapping, - typeParameterCount); - } - - /// - /// Extracts [IocGenericFactory] attribute from the method symbol and builds the type mapping. - /// - public GenericFactoryTypeMapping? ExtractGenericFactoryMapping() - { - // Only applicable to generic methods - if(methodSymbol.TypeParameters.Length == 0) - { - return null; - } - - // Find [IocGenericFactory] attribute - AttributeData? genericFactoryAttr = null; - foreach(var attr in methodSymbol.GetAttributes()) - { - var attrClass = attr.AttributeClass; - if(attrClass is null) - continue; - - var fullName = attrClass.ToDisplayString(); - if(fullName == Constants.IocGenericFactoryAttributeFullName) - { - genericFactoryAttr = attr; - break; - } - } - - if(genericFactoryAttr is null) - { - return null; - } - - // Extract GenericTypeMap array from constructor argument - // [IocGenericFactory(typeof(IRequestHandler>), typeof(int))] - // - First type: service type template with placeholders - // - Following types: map to factory method type parameters in order - if(genericFactoryAttr.ConstructorArguments.Length == 0) - { - return null; - } - - var firstArg = genericFactoryAttr.ConstructorArguments[0]; - if(firstArg.Kind != TypedConstantKind.Array || firstArg.Values.IsDefaultOrEmpty) - { - return null; - } - - var typeArray = firstArg.Values; - if(typeArray.Length < 2) - { - return null; // Need at least service type template and one placeholder mapping - } - - // First type is the service type template - if(typeArray[0].Value is not INamedTypeSymbol serviceTypeTemplate) - { - return null; - } - - var serviceTypeTemplateData = serviceTypeTemplate.GetTypeData(); - - // Build placeholder to type parameter index map - // Following types (index 1, 2, ...) map to factory method's type parameters (index 0, 1, ...) - var placeholderMap = new Dictionary(StringComparer.Ordinal); - var expectedPlaceholderCount = typeArray.Length - 1; - for(int i = 1; i < typeArray.Length; i++) - { - if(typeArray[i].Value is ITypeSymbol placeholderType) - { - var placeholderTypeName = placeholderType.FullyQualifiedName; - - // If the same placeholder type is used multiple times, the mapping is invalid - // because we cannot distinguish which type argument maps to which type parameter - if(placeholderMap.ContainsKey(placeholderTypeName)) - { - return null; - } - - // Map placeholder type to factory method's type parameter index (0-based) - placeholderMap[placeholderTypeName] = i - 1; - } - } - - // All placeholder types must be unique and present - if(placeholderMap.Count != expectedPlaceholderCount) - { - return null; - } - - return new GenericFactoryTypeMapping( - serviceTypeTemplateData, - placeholderMap.ToImmutableEquatableDictionary()); - } - } - - extension(INamedTypeSymbol typeSymbol) - { - /// - /// Extracts injection members (properties, fields, methods with [IocInject]/[Inject] attributes) from the type. - /// This is used for both regular registrations and decorators. - /// - /// Optional semantic model to resolve full access paths for nameof() expressions. - /// An array of injection member data. - public ImmutableEquatableArray ExtractInjectionMembersForDecorator(SemanticModel? semanticModel = null) - { - List? injectionMembers = null; - - foreach(var (member, injectAttribute) in typeSymbol.GetInjectedMembers()) - { - // Extract key information from IocInjectAttribute/InjectAttribute - var (key, _, _) = injectAttribute.GetKeyInfo(semanticModel); - - InjectionMemberData? memberData = member switch - { - IPropertySymbol property => CreateDecoratorPropertyInjection(property, key), - IFieldSymbol field => CreateDecoratorFieldInjection(field, key), - IMethodSymbol method => CreateDecoratorMethodInjection(method, key, semanticModel), - _ => null - }; - - if(memberData is not null) - { - injectionMembers ??= []; - injectionMembers.Add(memberData); - } - } - - return injectionMembers?.ToImmutableEquatableArray() ?? []; - } - - private static InjectionMemberData CreateDecoratorPropertyInjection(IPropertySymbol property, string? key) - { - var propertyType = property.Type.GetTypeData(); - var isNullable = property.NullableAnnotation == NullableAnnotation.Annotated; - - // Try to get the default value from property initializer - var (hasDefaultValue, defaultValue) = GetDecoratorPropertyDefaultValue(property); - - return new InjectionMemberData( - InjectionMemberType.Property, - property.Name, - propertyType, - null, - key, - isNullable, - hasDefaultValue, - defaultValue); - } - - private static InjectionMemberData CreateDecoratorFieldInjection(IFieldSymbol field, string? key) - { - var fieldType = field.Type.GetTypeData(); - var isNullable = field.NullableAnnotation == NullableAnnotation.Annotated; - - // Try to get the default value from field initializer - var (hasDefaultValue, defaultValue) = GetDecoratorFieldDefaultValue(field); - - return new InjectionMemberData( - InjectionMemberType.Field, - field.Name, - fieldType, - null, - key, - isNullable, - hasDefaultValue, - defaultValue); - } - - private static (bool HasDefaultValue, string? DefaultValue) GetDecoratorPropertyDefaultValue(IPropertySymbol property) - { - var syntaxRef = property.DeclaringSyntaxReferences.FirstOrDefault(); - if(syntaxRef?.GetSyntax() is not PropertyDeclarationSyntax propertySyntax) - return (false, null); - - var initializer = propertySyntax.Initializer; - if(initializer is null) - return (false, null); - - // Check if it's a null literal or null-forgiving expression (null!) - if(IsDecoratorNullExpression(initializer.Value)) - { - return (true, null); - } - - return (true, initializer.Value.ToString()); - } - - private static (bool HasDefaultValue, string? DefaultValue) GetDecoratorFieldDefaultValue(IFieldSymbol field) - { - var syntaxRef = field.DeclaringSyntaxReferences.FirstOrDefault(); - var syntax = syntaxRef?.GetSyntax(); - - // Field can be declared in VariableDeclaratorSyntax - EqualsValueClauseSyntax? initializer = syntax switch - { - VariableDeclaratorSyntax variableDeclarator => variableDeclarator.Initializer, - _ => null - }; - - if(initializer is null) - return (false, null); - - // Check if it's a null literal or null-forgiving expression (null!) - if(IsDecoratorNullExpression(initializer.Value)) - { - return (true, null); - } - - return (true, initializer.Value.ToString()); - } - - private static bool IsDecoratorNullExpression(ExpressionSyntax expression) - { - // Direct null literal - if(expression is LiteralExpressionSyntax literal && - literal.Kind() == SyntaxKind.NullLiteralExpression) - { - return true; - } - - // Null-forgiving expression: null! - if(expression is PostfixUnaryExpressionSyntax postfix && - postfix.Kind() == SyntaxKind.SuppressNullableWarningExpression && - postfix.Operand is LiteralExpressionSyntax innerLiteral && - innerLiteral.Kind() == SyntaxKind.NullLiteralExpression) - { - return true; - } - - return false; - } - - private static InjectionMemberData CreateDecoratorMethodInjection(IMethodSymbol method, string? key, SemanticModel? semanticModel) - { - var parameters = method.Parameters - .Select(p => - { - var (serviceKey, hasInjectAttribute, hasServiceKeyAttribute, hasFromKeyedServicesAttribute) = p.GetServiceKeyAndAttributeInfo(semanticModel); - return new ParameterData( - p.Name, - p.Type.GetTypeData(), - IsNullable: p.NullableAnnotation == NullableAnnotation.Annotated, - HasDefaultValue: p.HasExplicitDefaultValue, - DefaultValue: p.HasExplicitDefaultValue ? DecoratorToDefaultValueCodeString(p.ExplicitDefaultValue) : null, - ServiceKey: serviceKey, - HasInjectAttribute: hasInjectAttribute, - HasServiceKeyAttribute: hasServiceKeyAttribute, - HasFromKeyedServicesAttribute: hasFromKeyedServicesAttribute); - }) - .ToImmutableEquatableArray(); - - return new InjectionMemberData( - InjectionMemberType.Method, - method.Name, - null, - parameters, - key); - } - - private static string? DecoratorToDefaultValueCodeString(object? value) - { - return value switch - { - null => null, - string s => $"\"{s}\"", - char c => $"'{c}'", - bool b => b ? "true" : "false", - _ => value.ToString() - }; - } - } -} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TypeData.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TypeData.cs index 2d0b1f1..0972129 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TypeData.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Models/TypeData.cs @@ -357,9 +357,14 @@ internal sealed record class TaskTypeData( ConstructorParameters, HasInjectConstructor, InjectionMembers, AllInterfaces, AllBaseClasses); +/// +/// Wrapper classification result containing wrapper kind and extracted element type. +/// +internal readonly record struct WrapperInfo(WrapperKind Kind, INamedTypeSymbol? ElementType); + /// /// Represents the kind of wrapper for DI injection purposes. -/// Each value has a corresponding sealed TypeData derived type. +/// Most values have a corresponding sealed TypeData derived type; analyzer-only kinds (e.g., ValueTask) do not. /// internal enum WrapperKind { @@ -427,7 +432,13 @@ internal enum WrapperKind /// Task<T> - async-initialized service wrapper. /// Resolved via an async resolver method that awaits async inject methods. ///
- Task + Task, + + /// + /// ValueTask<T> - async-initialized service wrapper. + /// Analyzer recognizes this wrapper kind for validation paths. + /// + ValueTask } internal static class TypeDataExtensions diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/CombineAndResolveClosedGenerics.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/CombineAndResolveClosedGenerics.cs similarity index 98% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/CombineAndResolveClosedGenerics.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/CombineAndResolveClosedGenerics.cs index 03b11ef..6449f5d 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/CombineAndResolveClosedGenerics.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/CombineAndResolveClosedGenerics.cs @@ -17,7 +17,7 @@ private static ImmutableEquatableArray CombineAndRe in ImmutableArray factoryBasedOpenGenericEntries, CancellationToken ct) { - var registrations = new List(); + List registrations = []; // Use List to allow multiple implementations per service type key // (e.g., GenericRequestHandler and GenericRequestHandler2 both implement IRequestHandler<,>) @@ -38,7 +38,7 @@ private static ImmutableEquatableArray CombineAndRe // Index open generic entries - store ALL implementations per service type key if(result.OpenGenericEntries.Length > 0) { - openGenericIndex ??= new Dictionary>(StringComparer.Ordinal); + openGenericIndex ??= new Dictionary>(result.OpenGenericEntries.Length, StringComparer.Ordinal); foreach(var entry in result.OpenGenericEntries) { if(!openGenericIndex.TryGetValue(entry.ServiceTypeKey, out var list)) @@ -53,7 +53,7 @@ private static ImmutableEquatableArray CombineAndRe // Collect closed generic dependencies from constructor parameters if(result.ClosedGenericDependencies.Length > 0) { - closedGenericDependencies ??= new Dictionary(StringComparer.Ordinal); + closedGenericDependencies ??= new Dictionary(result.ClosedGenericDependencies.Length, StringComparer.Ordinal); foreach(var dep in result.ClosedGenericDependencies) { if(!closedGenericDependencies.ContainsKey(dep.ClosedTypeName)) @@ -67,7 +67,7 @@ private static ImmutableEquatableArray CombineAndRe // Index factory-based open generic entries from IocRegisterDefaults with Factory if(factoryBasedOpenGenericEntries.Length > 0) { - openGenericIndex ??= new Dictionary>(StringComparer.Ordinal); + openGenericIndex ??= new Dictionary>(factoryBasedOpenGenericEntries.Length, StringComparer.Ordinal); foreach(var entry in factoryBasedOpenGenericEntries) { if(!openGenericIndex.TryGetValue(entry.ServiceTypeKey, out var list)) @@ -82,7 +82,7 @@ private static ImmutableEquatableArray CombineAndRe // Collect closed generic dependencies from GetService/GetRequiredService invocations if(serviceProviderInvocations.Length > 0) { - closedGenericDependencies ??= new Dictionary(StringComparer.Ordinal); + closedGenericDependencies ??= new Dictionary(serviceProviderInvocations.Length, StringComparer.Ordinal); foreach(var dep in serviceProviderInvocations) { if(!closedGenericDependencies.ContainsKey(dep.ClosedTypeName)) @@ -98,7 +98,7 @@ private static ImmutableEquatableArray CombineAndRe { GenerateClosedGenericFactoryRegistrations( openGenericIndex, - closedGenericDependencies ?? new Dictionary(StringComparer.Ordinal), + closedGenericDependencies, registrations, ct); } @@ -112,7 +112,7 @@ private static ImmutableEquatableArray CombineAndRe ///
private static void GenerateClosedGenericFactoryRegistrations( Dictionary> openGenericIndex, - Dictionary closedGenericDependencies, + Dictionary? closedGenericDependencies, List registrations, CancellationToken ct) { @@ -135,7 +135,9 @@ private static void GenerateClosedGenericFactoryRegistrations( } // Use a queue for iterative processing of dependencies - var pendingDependencies = new Queue(closedGenericDependencies.Values); + var pendingDependencies = closedGenericDependencies is not null + ? new Queue(closedGenericDependencies.Values) + : new Queue(); var processedDependencies = new HashSet(StringComparer.Ordinal); // No dependencies to process - nothing to do @@ -774,7 +776,7 @@ private static List BuildClosedServiceTypesFromServiceTypeMap( TypeArgMap serviceTypeArgMap, TypeArgMap implTypeArgMap) { - var result = new List(); + var result = new List(openServiceTypes.Length); foreach(var openServiceType in openServiceTypes) { diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/ProcessSingleRegistration.Defaults.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/ProcessSingleRegistration.Defaults.cs new file mode 100644 index 0000000..6deb83b --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/ProcessSingleRegistration.Defaults.cs @@ -0,0 +1,106 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Gets additional service types from default settings and matched types. + /// + private static IEnumerable GetAdditionalServiceTypesFromDefaults( + DefaultSettingsModel? matchingDefault, + List matchedServiceTypes) + { + if(matchingDefault is not null) + { + foreach(var st in matchingDefault.ServiceTypes) + { + yield return st; + } + } + + foreach(var matchedType in matchedServiceTypes) + { + yield return matchedType; + } + } + + /// + /// Finds matching default settings from base classes and interfaces. + /// + /// The best matching default index, or -1 if none found. + private static int FindMatchingDefaults( + ImmutableEquatableArray baseClasses, + ImmutableEquatableArray interfaces, + DefaultSettingsMap defaultSettings, + List matchedDefaultIndices, + List matchedServiceTypes) + { + int bestDefaultIndex = -1; + + foreach(var candidate in baseClasses) + { + TryMatchDefaultSettings(candidate, defaultSettings, matchedDefaultIndices, matchedServiceTypes, ref bestDefaultIndex); + } + + foreach(var candidate in interfaces) + { + TryMatchDefaultSettings(candidate, defaultSettings, matchedDefaultIndices, matchedServiceTypes, ref bestDefaultIndex); + } + + return bestDefaultIndex; + } + + /// + /// Attempts to match a candidate type against default settings. + /// + private static void TryMatchDefaultSettings( + TypeData candidate, + DefaultSettingsMap defaultSettings, + List matchedDefaultIndices, + List matchedServiceTypes, + ref int bestDefaultIndex) + { + // Check exact matches + if(defaultSettings.TryGetExactMatches(candidate.Name, out var index) && !matchedDefaultIndices.Contains(index)) + { + matchedDefaultIndices.Add(index); + matchedServiceTypes.Add(candidate); + if(bestDefaultIndex < 0) bestDefaultIndex = index; + } + + // Check generic matches (only if type has generic parameters) + if(candidate is GenericTypeData candidateGeneric + && (candidateGeneric.IsOpenGeneric || candidate.Name != candidateGeneric.NameWithoutGeneric)) + { + if(defaultSettings.TryGetGenericMatches(candidateGeneric.NameWithoutGeneric, candidateGeneric.GenericArity, out var gIndex) + && !matchedDefaultIndices.Contains(gIndex)) + { + matchedDefaultIndices.Add(gIndex); + matchedServiceTypes.Add(candidate); + if(bestDefaultIndex < 0) bestDefaultIndex = gIndex; + } + } + } + + /// + /// Merges registration settings with default settings. + /// + private static (ServiceLifetime Lifetime, bool RegisterAllInterfaces, bool RegisterAllBaseClasses) MergeSettings( + RegistrationData registration, + DefaultSettingsModel? matchingDefault, + ServiceLifetime fallbackLifetime) + { + var lifetime = registration.HasExplicitLifetime + ? registration.Lifetime + : (matchingDefault?.Lifetime ?? fallbackLifetime); + + var registerAllInterfaces = registration.HasExplicitRegisterAllInterfaces + ? registration.RegisterAllInterfaces + : (matchingDefault?.RegisterAllInterfaces ?? registration.RegisterAllInterfaces); + + var registerAllBaseClasses = registration.HasExplicitRegisterAllBaseClasses + ? registration.RegisterAllBaseClasses + : (matchingDefault?.RegisterAllBaseClasses ?? registration.RegisterAllBaseClasses); + + return (lifetime, registerAllInterfaces, registerAllBaseClasses); + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/ProcessSingleRegistration.Parameters.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/ProcessSingleRegistration.Parameters.cs new file mode 100644 index 0000000..0417d69 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/ProcessSingleRegistration.Parameters.cs @@ -0,0 +1,185 @@ +namespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + /// + /// Collects closed generic dependencies from a registration's constructor parameters, injection members, + /// factory method parameters, and closed decorators' constructor parameters and injection members. + /// + private static ImmutableEquatableArray CollectClosedGenericDependenciesFromRegistration( + RegistrationData registration, + ImmutableEquatableArray serviceRegistrations) + { + var constructorParams = registration.ImplementationType.ConstructorParameters; + var injectionMembers = registration.InjectionMembers; + var factoryParams = registration.Factory?.AdditionalParameters; + + // Check if we have any decorators with constructor parameters or injection members in the service registrations + var hasDecoratorDependencies = false; + foreach(var reg in serviceRegistrations) + { + foreach(var decorator in reg.Decorators) + { + // Check constructor parameters (> 1 because first param is the decorated service) + if(decorator.ConstructorParameters is { Length: > 1 }) + { + hasDecoratorDependencies = true; + break; + } + // Check injection members + if(decorator.InjectionMembers is { Length: > 0 }) + { + hasDecoratorDependencies = true; + break; + } + } + if(hasDecoratorDependencies) break; + } + + // Early exit if no constructor params, no injection members, no factory params, and no decorator dependencies + if((constructorParams is null || constructorParams.Length == 0) + && injectionMembers.Length == 0 + && (factoryParams is null || factoryParams.Length == 0) + && !hasDecoratorDependencies) + { + return []; + } + + var dependencies = new List(); + var addedKeys = new HashSet(StringComparer.Ordinal); + + // Collect from constructor parameters + if(constructorParams is not null) + { + foreach(var param in constructorParams) + { + CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); + } + } + + // Collect from injection members (properties, fields, methods with [Inject] attribute) + foreach(var member in injectionMembers) + { + // For properties and fields, check the member type + if(member.Type is not null) + { + CollectClosedGenericDependencyFromType(member.Type, dependencies, addedKeys); + } + + // For methods, check each parameter type + if(member.Parameters is not null) + { + foreach(var param in member.Parameters) + { + CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); + } + } + } + + // Collect from factory method's additional parameters + if(factoryParams is not null) + { + foreach(var param in factoryParams) + { + CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); + } + } + + // Collect from closed decorators' constructor parameters and injection members + // These are decorators that have been closed (type parameters substituted) for specific service types + foreach(var reg in serviceRegistrations) + { + foreach(var decorator in reg.Decorators) + { + // Collect from constructor parameters (skip first parameter - it's the decorated service) + if(decorator.ConstructorParameters is { Length: > 1 }) + { + for(int i = 1; i < decorator.ConstructorParameters.Length; i++) + { + var param = decorator.ConstructorParameters[i]; + CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); + } + } + + // Collect from injection members (properties, fields, methods with [Inject] attribute) + if(decorator.InjectionMembers is { Length: > 0 }) + { + foreach(var member in decorator.InjectionMembers) + { + // For properties and fields, check the member type + if(member.Type is not null) + { + CollectClosedGenericDependencyFromType(member.Type, dependencies, addedKeys); + } + + // For methods, check each parameter type + if(member.Parameters is not null) + { + foreach(var param in member.Parameters) + { + CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); + } + } + } + } + } + } + + return dependencies.ToImmutableEquatableArray(); + } + + /// + /// Collects closed generic dependency from a type and adds it to the dependencies list. + /// + /// The type to check for closed generic dependencies. + /// The list to add dependencies to. + /// Set of already added dependency keys to avoid duplicates. + private static void CollectClosedGenericDependencyFromType( + TypeData paramType, + List dependencies, + HashSet addedKeys) + { + // Extract inner types from wrapper types for closed generic dependency discovery. + // For example, Lazy> -> extract IHandler as a dependency. + var innerType = paramType switch + { + LazyTypeData l => l.InstanceType, + FuncTypeData f => f.ReturnType, + DictionaryTypeData d => d.ValueType, + KeyValuePairTypeData k => k.ValueType, + _ => (TypeData?)null + }; + if(innerType is not null) + { + // Recursively collect from the inner type (handles nested wrappers like Lazy>) + CollectClosedGenericDependencyFromType(innerType, dependencies, addedKeys); + } + + // Check if this is any enumerable-compatible type and extract element type for closed generic dependency. + // Collection wrappers use ElementType directly via CollectionWrapperTypeData; + // other types check direct generic type name and AllInterfaces for IEnumerable implementation. + var elementType = paramType switch + { + CollectionWrapperTypeData c => c.ElementType, + _ => paramType.TryGetEnumerableElementType() + }; + if(elementType is not null) + { + // Recursively collect from the element type (handles nested wrappers like IEnumerable>>) + CollectClosedGenericDependencyFromType(elementType, dependencies, addedKeys); + } + + // Check if this is a closed generic type (has generic arguments but is not open generic) + if(paramType is GenericTypeData { GenericArity: > 0, IsOpenGeneric: false, IsNestedOpenGeneric: false } genericParamType) + { + // Add the original type as a dependency (skip arrays as they don't need registration) + if(!paramType.IsArrayType && addedKeys.Add(paramType.Name)) + { + dependencies.Add(new ClosedGenericDependency( + paramType.Name, + paramType, + genericParamType.NameWithoutGeneric)); + } + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/ProcessSingleRegistration.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/ProcessSingleRegistration.cs similarity index 62% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/ProcessSingleRegistration.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/ProcessSingleRegistration.cs index 0e401e9..7aa8645 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/ProcessSingleRegistration.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Processing/ProcessSingleRegistration.cs @@ -9,7 +9,7 @@ partial class IocSourceGenerator ///
/// The registration data to process. /// The processed registration result. - private static BasicRegistrationResult ProcessSingleRegistrationFromDefaults(RegistrationData registration) + private static BasicRegistrationResult ProcessSingleRegistrationFromDefaults(RegistrationData registration, CancellationToken ct) { // Use the explicit settings from the registration (already set from defaults attribute) return ProcessRegistrationCore( @@ -20,7 +20,8 @@ private static BasicRegistrationResult ProcessSingleRegistrationFromDefaults(Reg decorators: registration.Decorators, tags: registration.Tags, factory: registration.Factory, - additionalServiceTypesFromDefaults: null); + additionalServiceTypesFromDefaults: null, + ct); } /// @@ -32,13 +33,15 @@ private static BasicRegistrationResult ProcessSingleRegistrationFromDefaults(Reg /// The processed registration result with all resolved settings. private static BasicRegistrationResult ProcessSingleRegistration( RegistrationData registration, - DefaultSettingsMap defaultSettings) + DefaultSettingsMap defaultSettings, + CancellationToken ct) { // Reusable buffers for default settings lookup var matchedDefaultIndices = new List(); var matchedServiceTypes = new List(); // Find matching default settings from base classes and interfaces + ct.ThrowIfCancellationRequested(); int bestDefaultIndex = FindMatchingDefaults( registration.AllBaseClasses, registration.AllInterfaces, @@ -77,28 +80,8 @@ private static BasicRegistrationResult ProcessSingleRegistration( decorators, tags, factory, - additionalServiceTypes); - } - - /// - /// Gets additional service types from default settings and matched types. - /// - private static IEnumerable GetAdditionalServiceTypesFromDefaults( - DefaultSettingsModel? matchingDefault, - List matchedServiceTypes) - { - if(matchingDefault is not null) - { - foreach(var st in matchingDefault.ServiceTypes) - { - yield return st; - } - } - - foreach(var matchedType in matchedServiceTypes) - { - yield return matchedType; - } + additionalServiceTypes, + ct); } /// @@ -112,18 +95,22 @@ private static BasicRegistrationResult ProcessRegistrationCore( ImmutableEquatableArray decorators, ImmutableEquatableArray tags, FactoryMethodData? factory, - IEnumerable? additionalServiceTypesFromDefaults) + IEnumerable? additionalServiceTypesFromDefaults, + CancellationToken ct) { + ct.ThrowIfCancellationRequested(); + var serviceTypesToRegister = new List { // Always register the implementation type itself registration.ImplementationType }; + var addedTypeNames = new HashSet(StringComparer.Ordinal) { registration.ImplementationType.Name }; // Add explicit service types from registration foreach(var st in registration.ServiceTypes) { - if(!serviceTypesToRegister.Contains(st)) + if(addedTypeNames.Add(st.Name)) { serviceTypesToRegister.Add(st); } @@ -134,7 +121,7 @@ private static BasicRegistrationResult ProcessRegistrationCore( { foreach(var st in additionalServiceTypesFromDefaults) { - if(!serviceTypesToRegister.Contains(st)) + if(addedTypeNames.Add(st.Name)) { serviceTypesToRegister.Add(st); } @@ -146,7 +133,7 @@ private static BasicRegistrationResult ProcessRegistrationCore( { foreach(var iface in registration.AllInterfaces) { - if(!serviceTypesToRegister.Contains(iface)) + if(addedTypeNames.Add(iface.Name)) { serviceTypesToRegister.Add(iface); } @@ -158,7 +145,7 @@ private static BasicRegistrationResult ProcessRegistrationCore( { foreach(var baseClass in registration.AllBaseClasses) { - if(!serviceTypesToRegister.Contains(baseClass)) + if(addedTypeNames.Add(baseClass.Name)) { serviceTypesToRegister.Add(baseClass); } @@ -193,87 +180,6 @@ private static BasicRegistrationResult ProcessRegistrationCore( closedGenericDependencies); } - /// - /// Finds matching default settings from base classes and interfaces. - /// - /// The best matching default index, or -1 if none found. - private static int FindMatchingDefaults( - ImmutableEquatableArray baseClasses, - ImmutableEquatableArray interfaces, - DefaultSettingsMap defaultSettings, - List matchedDefaultIndices, - List matchedServiceTypes) - { - int bestDefaultIndex = -1; - - foreach(var candidate in baseClasses) - { - TryMatchDefaultSettings(candidate, defaultSettings, matchedDefaultIndices, matchedServiceTypes, ref bestDefaultIndex); - } - - foreach(var candidate in interfaces) - { - TryMatchDefaultSettings(candidate, defaultSettings, matchedDefaultIndices, matchedServiceTypes, ref bestDefaultIndex); - } - - return bestDefaultIndex; - } - - /// - /// Attempts to match a candidate type against default settings. - /// - private static void TryMatchDefaultSettings( - TypeData candidate, - DefaultSettingsMap defaultSettings, - List matchedDefaultIndices, - List matchedServiceTypes, - ref int bestDefaultIndex) - { - // Check exact matches - if(defaultSettings.TryGetExactMatches(candidate.Name, out var index) && !matchedDefaultIndices.Contains(index)) - { - matchedDefaultIndices.Add(index); - matchedServiceTypes.Add(candidate); - if(bestDefaultIndex < 0) bestDefaultIndex = index; - } - - // Check generic matches (only if type has generic parameters) - if(candidate is GenericTypeData candidateGeneric - && (candidateGeneric.IsOpenGeneric || candidate.Name != candidateGeneric.NameWithoutGeneric)) - { - if(defaultSettings.TryGetGenericMatches(candidateGeneric.NameWithoutGeneric, candidateGeneric.GenericArity, out var gIndex) - && !matchedDefaultIndices.Contains(gIndex)) - { - matchedDefaultIndices.Add(gIndex); - matchedServiceTypes.Add(candidate); - if(bestDefaultIndex < 0) bestDefaultIndex = gIndex; - } - } - } - - /// - /// Merges registration settings with default settings. - /// - private static (ServiceLifetime Lifetime, bool RegisterAllInterfaces, bool RegisterAllBaseClasses) MergeSettings( - RegistrationData registration, - DefaultSettingsModel? matchingDefault, - ServiceLifetime fallbackLifetime) - { - var lifetime = registration.HasExplicitLifetime - ? registration.Lifetime - : (matchingDefault?.Lifetime ?? fallbackLifetime); - - var registerAllInterfaces = registration.HasExplicitRegisterAllInterfaces - ? registration.RegisterAllInterfaces - : (matchingDefault?.RegisterAllInterfaces ?? registration.RegisterAllInterfaces); - - var registerAllBaseClasses = registration.HasExplicitRegisterAllBaseClasses - ? registration.RegisterAllBaseClasses - : (matchingDefault?.RegisterAllBaseClasses ?? registration.RegisterAllBaseClasses); - - return (lifetime, registerAllInterfaces, registerAllBaseClasses); - } - /// /// Creates service registration models for each valid service type. /// @@ -659,185 +565,4 @@ private static ImmutableEquatableArray CreateOpenGenericEntrie return entries.ToImmutableEquatableArray(); } - - /// - /// Collects closed generic dependencies from a registration's constructor parameters, injection members, - /// factory method parameters, and closed decorators' constructor parameters and injection members. - /// - private static ImmutableEquatableArray CollectClosedGenericDependenciesFromRegistration( - RegistrationData registration, - ImmutableEquatableArray serviceRegistrations) - { - var constructorParams = registration.ImplementationType.ConstructorParameters; - var injectionMembers = registration.InjectionMembers; - var factoryParams = registration.Factory?.AdditionalParameters; - - // Check if we have any decorators with constructor parameters or injection members in the service registrations - var hasDecoratorDependencies = false; - foreach(var reg in serviceRegistrations) - { - foreach(var decorator in reg.Decorators) - { - // Check constructor parameters (> 1 because first param is the decorated service) - if(decorator.ConstructorParameters is { Length: > 1 }) - { - hasDecoratorDependencies = true; - break; - } - // Check injection members - if(decorator.InjectionMembers is { Length: > 0 }) - { - hasDecoratorDependencies = true; - break; - } - } - if(hasDecoratorDependencies) break; - } - - // Early exit if no constructor params, no injection members, no factory params, and no decorator dependencies - if((constructorParams is null || constructorParams.Length == 0) - && injectionMembers.Length == 0 - && (factoryParams is null || factoryParams.Length == 0) - && !hasDecoratorDependencies) - { - return []; - } - - var dependencies = new List(); - var addedKeys = new HashSet(StringComparer.Ordinal); - - // Collect from constructor parameters - if(constructorParams is not null) - { - foreach(var param in constructorParams) - { - CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); - } - } - - // Collect from injection members (properties, fields, methods with [Inject] attribute) - foreach(var member in injectionMembers) - { - // For properties and fields, check the member type - if(member.Type is not null) - { - CollectClosedGenericDependencyFromType(member.Type, dependencies, addedKeys); - } - - // For methods, check each parameter type - if(member.Parameters is not null) - { - foreach(var param in member.Parameters) - { - CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); - } - } - } - - // Collect from factory method's additional parameters - if(factoryParams is not null) - { - foreach(var param in factoryParams) - { - CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); - } - } - - // Collect from closed decorators' constructor parameters and injection members - // These are decorators that have been closed (type parameters substituted) for specific service types - foreach(var reg in serviceRegistrations) - { - foreach(var decorator in reg.Decorators) - { - // Collect from constructor parameters (skip first parameter - it's the decorated service) - if(decorator.ConstructorParameters is { Length: > 1 }) - { - for(int i = 1; i < decorator.ConstructorParameters.Length; i++) - { - var param = decorator.ConstructorParameters[i]; - CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); - } - } - - // Collect from injection members (properties, fields, methods with [Inject] attribute) - if(decorator.InjectionMembers is { Length: > 0 }) - { - foreach(var member in decorator.InjectionMembers) - { - // For properties and fields, check the member type - if(member.Type is not null) - { - CollectClosedGenericDependencyFromType(member.Type, dependencies, addedKeys); - } - - // For methods, check each parameter type - if(member.Parameters is not null) - { - foreach(var param in member.Parameters) - { - CollectClosedGenericDependencyFromType(param.Type, dependencies, addedKeys); - } - } - } - } - } - } - - return dependencies.ToImmutableEquatableArray(); - } - - /// - /// Collects closed generic dependency from a type and adds it to the dependencies list. - /// - /// The type to check for closed generic dependencies. - /// The list to add dependencies to. - /// Set of already added dependency keys to avoid duplicates. - private static void CollectClosedGenericDependencyFromType( - TypeData paramType, - List dependencies, - HashSet addedKeys) - { - // Extract inner types from wrapper types for closed generic dependency discovery. - // For example, Lazy> β†’ extract IHandler as a dependency. - var innerType = paramType switch - { - LazyTypeData l => l.InstanceType, - FuncTypeData f => f.ReturnType, - DictionaryTypeData d => d.ValueType, - KeyValuePairTypeData k => k.ValueType, - _ => (TypeData?)null - }; - if(innerType is not null) - { - // Recursively collect from the inner type (handles nested wrappers like Lazy>) - CollectClosedGenericDependencyFromType(innerType, dependencies, addedKeys); - } - - // Check if this is any enumerable-compatible type and extract element type for closed generic dependency. - // Collection wrappers use ElementType directly via CollectionWrapperTypeData; - // other types check direct generic type name and AllInterfaces for IEnumerable implementation. - var elementType = paramType switch - { - CollectionWrapperTypeData c => c.ElementType, - _ => paramType.TryGetEnumerableElementType() - }; - if(elementType is not null) - { - // Recursively collect from the element type (handles nested wrappers like IEnumerable>>) - CollectClosedGenericDependencyFromType(elementType, dependencies, addedKeys); - } - - // Check if this is a closed generic type (has generic arguments but is not open generic) - if(paramType is GenericTypeData { GenericArity: > 0, IsOpenGeneric: false, IsNestedOpenGeneric: false } genericParamType) - { - // Add the original type as a dependency (skip arrays as they don't need registration) - if(!paramType.IsArrayType && addedKeys.Add(paramType.Name)) - { - dependencies.Add(new ClosedGenericDependency( - paramType.Name, - paramType, - genericParamType.NameWithoutGeneric)); - } - } - } } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.AttributeArguments.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.AttributeArguments.cs new file mode 100644 index 0000000..a87315e --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.AttributeArguments.cs @@ -0,0 +1,175 @@ +namespace SourceGen.Ioc.SourceGenerator; + +internal static partial class RoslynExtensions +{ + extension(AttributeData attributeData) + { + /// + /// Gets a named argument value from an attribute data. + /// + public T? GetNamedArgument(string name, T? defaultValue = default) + { + foreach(var namedArg in attributeData.NamedArguments) + { + if(namedArg.Key == name) + { + if(namedArg.Value.IsNull) + { + return defaultValue; + } + + return (T?)namedArg.Value.Value; + } + } + + return defaultValue; + } + + /// + /// Tries to get a named argument value from an attribute data.
+ /// If the argument is not found, returns HasArg = false. + ///
+ public (bool HasArg, T? Value) TryGetNamedArgument(string name, T? defaultValue = default) + { + foreach(var namedArg in attributeData.NamedArguments) + { + if(namedArg.Key == name) + { + if(namedArg.Value.IsNull) + { + return (true, defaultValue); + } + + return (true, (T?)namedArg.Value.Value); + } + } + + return (false, defaultValue); + } + + /// + /// Checks if a named argument was explicitly set in an attribute. + /// + public bool HasNamedArgument(string name) + { + foreach(var namedArg in attributeData.NamedArguments) + { + if(namedArg.Key == name) + { + return true; + } + } + + return false; + } + + /// + /// Tries to get the original syntax for a named argument, especially for expressions. + /// When a is provided, resolves the full access path of the referenced symbol. + /// + /// The name of the argument to find. + /// Optional semantic model to resolve full access paths for nameof() expressions. + /// The resolved symbol path if it's a expression; otherwise, null. + public string? TryGetNameof(string argumentName, SemanticModel? semanticModel = null) + { + var syntaxReference = attributeData.ApplicationSyntaxReference; + if(syntaxReference is null) + return null; + + var syntax = syntaxReference.GetSyntax(); + if(syntax is not AttributeSyntax attributeSyntax) + return null; + + var argumentList = attributeSyntax.ArgumentList; + if(argumentList is null) + return null; + + foreach(var argument in argumentList.Arguments) + { + // Check if this is a named argument with the correct name + if(argument.NameEquals?.Name.Identifier.Text == argumentName) + { + // Check if the expression is a nameof() invocation + if(argument.Expression is InvocationExpressionSyntax invocation && + invocation.Expression is IdentifierNameSyntax identifierName && + identifierName.Identifier.Text == "nameof") + { + // Extract the argument inside nameof() and return just that expression + if(invocation.ArgumentList.Arguments.Count == 1) + { + var nameofArgument = invocation.ArgumentList.Arguments[0].Expression; + + // If semantic model is provided, try to resolve the full access path + if(semanticModel is not null) + { + var resolvedPath = ResolveNameofExpression(nameofArgument, semanticModel); + if(resolvedPath is not null) + return resolvedPath; + } + + return nameofArgument.ToFullString().Trim(); + } + } + } + } + + return null; + } + + /// + /// Tries to extract the expression from a constructor argument of an attribute. + /// + /// The index of the constructor argument to check. + /// Optional semantic model to resolve full access paths for nameof() expressions. + /// The resolved symbol path if it's a expression; otherwise, null. + public string? TryGetNameofFromConstructorArg(int argumentIndex, SemanticModel? semanticModel = null) + { + var syntaxReference = attributeData.ApplicationSyntaxReference; + if(syntaxReference is null) + return null; + + var syntax = syntaxReference.GetSyntax(); + if(syntax is not AttributeSyntax attributeSyntax) + return null; + + var argumentList = attributeSyntax.ArgumentList; + if(argumentList is null || argumentList.Arguments.Count <= argumentIndex) + return null; + + var argument = argumentList.Arguments[argumentIndex]; + + // Skip named arguments (they don't count as constructor arguments) + if(argument.NameEquals is not null) + return null; + + // Check if the expression is a nameof() invocation + if(argument.Expression is InvocationExpressionSyntax invocation && + invocation.Expression is IdentifierNameSyntax identifierName && + identifierName.Identifier.Text == "nameof") + { + // Extract the argument inside nameof() and return just that expression + if(invocation.ArgumentList.Arguments.Count == 1) + { + var nameofArgument = invocation.ArgumentList.Arguments[0].Expression; + + // If semantic model is provided, try to resolve the full access path + if(semanticModel is not null) + { + var resolvedPath = ResolveNameofExpression(nameofArgument, semanticModel); + if(resolvedPath is not null) + return resolvedPath; + } + + return nameofArgument.ToFullString().Trim(); + } + } + + return null; + } + } + + extension(TypedConstant constant) + { + public string GetPrimitiveConstantString() => FormatPrimitiveConstant(constant.Type, constant.Value); + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Constructors.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Constructors.cs new file mode 100644 index 0000000..599aaff --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Constructors.cs @@ -0,0 +1,58 @@ +namespace SourceGen.Ioc.SourceGenerator; + +internal static partial class RoslynExtensions +{ + extension(INamedTypeSymbol typeSymbol) + { + public IMethodSymbol? PrimaryConstructor + { + get + { + foreach(var ctor in typeSymbol.Constructors) + { + if(ctor.IsImplicitlyDeclared) + continue; + + var syntaxRef = ctor.DeclaringSyntaxReferences.FirstOrDefault(); + if(syntaxRef?.GetSyntax() is TypeDeclarationSyntax) + return ctor; + } + + return null; + } + } + + public IMethodSymbol? PrimaryOrMostParametersConstructor + { + get + { + IMethodSymbol? bestCtor = null; + int maxParameters = -1; + foreach(var ctor in typeSymbol.Constructors) + { + if(ctor.IsImplicitlyDeclared) + continue; + + var syntaxRef = ctor.DeclaringSyntaxReferences.FirstOrDefault(); + // Primary constructor + if(syntaxRef?.GetSyntax() is TypeDeclarationSyntax) + return ctor; + + if(ctor.IsStatic) + continue; + + if(ctor.DeclaredAccessibility is not (Accessibility.Public or Accessibility.Internal)) + continue; + + // Find constructor with most parameters + if(ctor.Parameters.Length > maxParameters) + { + maxParameters = ctor.Parameters.Length; + bestCtor = ctor; + } + } + return bestCtor; + } + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.GenericTypes.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.GenericTypes.cs new file mode 100644 index 0000000..8d6b779 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.GenericTypes.cs @@ -0,0 +1,165 @@ +namespace SourceGen.Ioc.SourceGenerator; + +internal static partial class RoslynExtensions +{ + extension(ITypeSymbol typeSymbol) + { + public bool ContainsGenericParameters + { + get + { + if(typeSymbol.TypeKind is TypeKind.TypeParameter or TypeKind.Error) + { + return true; + } + + if(typeSymbol is INamedTypeSymbol namedTypeSymbol) + { + if(namedTypeSymbol.IsUnboundGenericType) + { + return true; + } + + for(; namedTypeSymbol != null; namedTypeSymbol = namedTypeSymbol.ContainingType) + { + if(namedTypeSymbol.TypeArguments.Any(arg => arg.ContainsGenericParameters)) + { + return true; + } + } + } + + return false; + } + } + + public INamedTypeSymbol? GetCompatibleGenericBaseType([NotNullWhen(true)] INamedTypeSymbol? genericType) + { + if(genericType is null) + { + return null; + } + + Debug.Assert(genericType.IsGenericTypeDefinition); + + if(genericType.TypeKind is TypeKind.Interface) + { + foreach(INamedTypeSymbol interfaceType in typeSymbol.AllInterfaces) + { + if(IsMatchingGenericType(interfaceType, genericType)) + { + return interfaceType; + } + } + } + + for(INamedTypeSymbol? current = typeSymbol as INamedTypeSymbol; current != null; current = current.BaseType) + { + if(IsMatchingGenericType(current, genericType)) + { + return current; + } + } + + return null; + + static bool IsMatchingGenericType(INamedTypeSymbol candidate, INamedTypeSymbol baseType) + { + return candidate.IsGenericType && SymbolEqualityComparer.Default.Equals(candidate.ConstructedFrom, baseType); + } + } + } + + extension(INamedTypeSymbol typeSymbol) + { + public bool IsGenericTypeDefinition => typeSymbol is { IsGenericType: true, IsDefinition: true }; + + /// + /// Gets the type parameters source for this type symbol. + /// For unbound generic types and constructed generic types, returns parameters from the original definition. + /// This allows matching type parameter names (TRequest, TResponse) with type arguments (TestRequest, List<string>). + /// + public ImmutableArray TypeParametersSource => + typeSymbol.IsGenericType + ? typeSymbol.OriginalDefinition?.TypeParameters ?? typeSymbol.TypeParameters + : typeSymbol.TypeParameters; + + /// + /// Determines whether the type is a nested open generic type. + /// A nested open generic is a generic type where any type argument itself contains generic parameters. + /// For example: IGeneric<IGeneric2<T>> is a nested open generic. + /// But IGeneric<T> or IGeneric<int> are not. + /// + public bool IsNestedOpenGeneric + { + get + { + if(!typeSymbol.IsGenericType) + { + return false; + } + + // For unbound generic types (e.g., IRepository<>), TypeArguments contains error types + // which should not be considered as nested open generics + if(typeSymbol.IsUnboundGenericType) + { + return false; + } + + // Check if any type argument contains generic parameters + foreach(var typeArg in typeSymbol.TypeArguments) + { + // If the type argument is not a simple type parameter (T, T1, etc.) + // but contains generic parameters, it's a nested open generic + if(typeArg.TypeKind != TypeKind.TypeParameter && typeArg.ContainsGenericParameters) + { + return true; + } + } + + return false; + } + } + } + + /// + /// Checks if sourceType is assignable to targetType. + /// Returns true if a value of sourceType can be assigned to a variable of targetType. + /// + /// The target type (e.g., parameter type). + /// The source type (e.g., value type). + /// True if sourceType is assignable to targetType. + public static bool IsAssignable(ITypeSymbol targetType, ITypeSymbol sourceType) + { + // Exact match + if(SymbolEqualityComparer.Default.Equals(targetType, sourceType)) + return true; + + // Handle nullable types - if target is nullable and source type is the underlying type + if(targetType.OriginalDefinition.SpecialType is SpecialType.System_Nullable_T + && targetType is INamedTypeSymbol nullableTarget) + { + var underlyingType = nullableTarget.TypeArguments.FirstOrDefault(); + if(underlyingType is not null && SymbolEqualityComparer.Default.Equals(underlyingType, sourceType)) + return true; + } + + // Handle object type - any type is assignable to object + if(targetType.SpecialType is SpecialType.System_Object) + return true; + + // Handle inheritance - target type should be a base type or interface of source type + if(sourceType.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, targetType))) + return true; + + var currentBase = sourceType.BaseType; + while(currentBase is not null) + { + if(SymbolEqualityComparer.Default.Equals(currentBase, targetType)) + return true; + currentBase = currentBase.BaseType; + } + + return false; + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Identifiers.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Identifiers.cs new file mode 100644 index 0000000..04041c8 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Identifiers.cs @@ -0,0 +1,88 @@ +namespace SourceGen.Ioc.SourceGenerator; + +internal static partial class RoslynExtensions +{ + /// + /// Converts a string to a safe C# namespace name. + /// Replaces invalid characters (like '-') with underscores, preserving '.' as namespace separator. + /// + /// The namespace name to convert. + /// A safe C# namespace name. + public static string GetSafeNamespace(string name) + { + if(string.IsNullOrWhiteSpace(name)) + return "Generated"; + + ReadOnlySpan nameSpan = name.AsSpan(); + + // Check if first char is digit (needs underscore prefix) + var needsPrefix = nameSpan.Length > 0 && char.IsDigit(nameSpan[0]); + var maxLength = nameSpan.Length + (needsPrefix ? 1 : 0); + + // Use stackalloc for small strings (up to 256 chars), otherwise use array pool + const int StackAllocThreshold = 256; + Span buffer = maxLength <= StackAllocThreshold + ? stackalloc char[StackAllocThreshold] + : new char[maxLength]; + + var writeIndex = 0; + + if(needsPrefix) + { + buffer[writeIndex++] = '_'; + } + + for(var i = 0; i < nameSpan.Length; i++) + { + var ch = nameSpan[i]; + // Allow letters, digits, underscore, and dot (namespace separator) + buffer[writeIndex++] = char.IsLetterOrDigit(ch) || ch is '_' or '.' ? ch : '_'; + } + + return buffer[..writeIndex].ToString(); + } + + /// + /// Converts a string to a safe C# identifier. + /// Removes all global:: prefixes and replaces non-identifier characters with underscores. + /// Uses stack allocation for small strings to reduce heap allocations. + /// + /// The name to convert. + /// The fallback value if name is null or whitespace. Default is "Generated". + /// A safe C# identifier. + public static string GetSafeIdentifier(string name, string fallback = "Generated") + { + if(string.IsNullOrWhiteSpace(name)) + return fallback; + + // Remove all global:: prefixes (not just the first one, e.g., in generic types) + var processedName = name.Replace("global::", ""); + + ReadOnlySpan nameSpan = processedName.AsSpan(); + + // Check if first char is digit (needs underscore prefix) + var needsPrefix = nameSpan.Length > 0 && char.IsDigit(nameSpan[0]); + var maxLength = nameSpan.Length + (needsPrefix ? 1 : 0); + + // Use stackalloc for small strings (up to 256 chars), otherwise use array pool + const int StackAllocThreshold = 256; + Span buffer = maxLength <= StackAllocThreshold + ? stackalloc char[StackAllocThreshold] + : new char[maxLength]; + + var writeIndex = 0; + + if(needsPrefix) + { + buffer[writeIndex++] = '_'; + } + + for(var i = 0; i < nameSpan.Length; i++) + { + var ch = nameSpan[i]; + buffer[writeIndex++] = char.IsLetterOrDigit(ch) || ch == '_' ? ch : '_'; + } + + return buffer[..writeIndex].ToString(); + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Misc.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Misc.cs new file mode 100644 index 0000000..d0f689c --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Misc.cs @@ -0,0 +1,36 @@ +namespace SourceGen.Ioc.SourceGenerator; + +internal static partial class RoslynExtensions +{ + /// + /// Returns when the method returns the non-generic + /// type (arity 0). + /// + internal static bool IsNonGenericTaskReturnType(IMethodSymbol method) + => method.ReturnType is INamedTypeSymbol { Arity: 0, Name: "Task" } named + && named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; + + extension(IEnumerable source) + { + public IEnumerable<(int Index, T Item)> Index() + { + int index = 0; + foreach(var item in source) + { + yield return (index, item); + checked { index++; } + } + } + } + + extension(IReadOnlyList source) + { + public IEnumerable<(int Index, T Item)> Index() + { + for(int i = 0; i < source.Count; i++) + { + yield return (i, source[i]); + } + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.NameofResolution.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.NameofResolution.cs new file mode 100644 index 0000000..e96263a --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.NameofResolution.cs @@ -0,0 +1,39 @@ +namespace SourceGen.Ioc.SourceGenerator; + +internal static partial class RoslynExtensions +{ + /// + /// Resolves the full access path of a symbol referenced in a nameof() expression. + /// For example, resolves nameof(Key) to global::Namespace.OuterClass.InnerClass.Key + /// when Key is a member of InnerClass inside OuterClass. + /// + /// The expression inside nameof(). + /// The semantic model to use for symbol resolution. + /// The full access path if successfully resolved; otherwise, null. + public static string? ResolveNameofExpression(ExpressionSyntax expression, SemanticModel semanticModel) + { + var symbolInfo = semanticModel.GetSymbolInfo(expression); + var symbol = symbolInfo.Symbol ?? symbolInfo.CandidateSymbols.FirstOrDefault(); + + if(symbol is null) + return null; + + // If the expression already contains a member access (e.g., KeyHolder.Key), + // we need to resolve it to ensure we get the fully qualified path + return symbol.FullAccessPath; + } + + /// + /// Resolves the method symbol from a nameof() or string expression in an attribute. + /// + /// The expression inside nameof() or a string literal. + /// The semantic model to use for symbol resolution. + /// The method symbol if found; otherwise, null. + public static IMethodSymbol? ResolveMethodSymbol(ExpressionSyntax expression, SemanticModel semanticModel) + { + var symbolInfo = semanticModel.GetSymbolInfo(expression); + var symbol = symbolInfo.Symbol ?? symbolInfo.CandidateSymbols.FirstOrDefault(); + + return symbol as IMethodSymbol; + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.SymbolDisplay.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.SymbolDisplay.cs new file mode 100644 index 0000000..7a3a21b --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.SymbolDisplay.cs @@ -0,0 +1,57 @@ +namespace SourceGen.Ioc.SourceGenerator; + +/// +/// Extension methods for Roslyn symbol manipulation. +/// +internal static partial class RoslynExtensions +{ + /// The symbol + extension(ISymbol symbol) + { + /// + /// Builds the fully qualified access path for a symbol, including namespace and containing types.
+ /// For example, for a field Key inside NestClassImpl inside TestNestClass in namespace MyApp.Services, + /// returns global::MyApp.Services.TestNestClass.NestClassImpl.Key. + ///
+ /// The fully qualified access path for the symbol. + public string FullAccessPath + { + get + { + // Build path from the symbol up to its containing types (collect in reverse order, then reverse) + List pathParts = [symbol.Name]; + + var containingType = symbol.ContainingType; + while(containingType is not null) + { + pathParts.Add(containingType.Name); + containingType = containingType.ContainingType; + } + + // Reverse to get correct order (outermost type first) + pathParts.Reverse(); + + // Add namespace prefix with global:: + var containingNamespace = symbol.ContainingType?.ContainingNamespace ?? symbol.ContainingNamespace; + if(containingNamespace is not null && !containingNamespace.IsGlobalNamespace) + { + var namespacePath = containingNamespace.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + return $"{namespacePath}.{string.Join(".", pathParts)}"; + } + + // For global namespace, just prepend global:: + return $"global::{string.Join(".", pathParts)}"; + } + } + } + + extension(ITypeSymbol typeSymbol) + { + /// + /// Gets the fully qualified name of a type symbol. + /// + public string FullyQualifiedName => typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + public bool IsNullable => !typeSymbol.IsValueType || typeSymbol.OriginalDefinition.SpecialType is SpecialType.System_Nullable_T; + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.ValueFormatting.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.ValueFormatting.cs new file mode 100644 index 0000000..6bac1b6 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.ValueFormatting.cs @@ -0,0 +1,161 @@ +using System.Globalization; + +namespace SourceGen.Ioc.SourceGenerator; + +internal static partial class RoslynExtensions +{ + public static string FormatPrimitiveConstant(ITypeSymbol? type, object? value) + { + if(type?.OriginalDefinition.SpecialType is SpecialType.System_Nullable_T) + { + var elementType = ((INamedTypeSymbol)type).TypeArguments[0]; + return value is null ? "null" : FormatPrimitiveConstant(elementType, value); + } + + if(type?.TypeKind is TypeKind.Enum) + { + return FormatEnumLiteral((INamedTypeSymbol)type, value!); + } + + return value switch + { + null => type?.IsNullable is null or true ? "null!" : "default", + false => "false", + true => "true", + + string s => SymbolDisplay.FormatLiteral(s, quote: true), + char c => SymbolDisplay.FormatLiteral(c, quote: true), + + double.NaN => "double.NaN", + double.NegativeInfinity => "double.NegativeInfinity", + double.PositiveInfinity => "double.PositiveInfinity", + double d => $"{d.ToString("G17", CultureInfo.InvariantCulture)}d", + + float.NaN => "float.NaN", + float.NegativeInfinity => "float.NegativeInfinity", + float.PositiveInfinity => "float.PositiveInfinity", + float f => $"{f.ToString("G9", CultureInfo.InvariantCulture)}f", + + decimal d => $"{d.ToString(CultureInfo.InvariantCulture)}m", + + // Must be one of the other numeric types or an enum + object num => Convert.ToString(num, CultureInfo.InvariantCulture), + }; + + static string FormatEnumLiteral(INamedTypeSymbol enumType, object value) + { + Debug.Assert(enumType.TypeKind is TypeKind.Enum); + + foreach(ISymbol member in enumType.GetMembers()) + { + if(member is IFieldSymbol { IsConst: true, ConstantValue: { } constantValue } field) + { + if(Equals(constantValue, value)) + { + return FormatEnumField(field); + } + } + } + + bool isFlagsEnum = enumType.GetAttributes().Any(attr => + attr.AttributeClass?.Name == "FlagsAttribute" && + attr.AttributeClass.ContainingNamespace.ToDisplayString() == "System"); + + if(isFlagsEnum) + { + // Convert the value to ulong for bitwise operations + ulong numericValue = ConvertToUInt64(value); + var fields = enumType.GetMembers().OfType() + .Select((f, i) => (Index: i, Symbol: f, NumericValue: ConvertToUInt64(f.ConstantValue!))) + .ToArray(); + + // Check for any zero numeric values. + if(numericValue == 0) + { + foreach(var field in fields) + { + if(field.NumericValue == 0) + { + return FormatEnumField(field.Symbol); + } + } + } + else + { + List? matches = null; + foreach(var field in fields.OrderByDescending(f => f.NumericValue)) + { + // Greedy match of flag values from highest to lowest numeric value. + if(field.NumericValue != 0 && (numericValue & field.NumericValue) == field.NumericValue) + { + (matches ??= []).Add(field.Index); + numericValue &= ~field.NumericValue; + if(numericValue == 0) + { + break; // All bits accounted for + } + } + } + + if(numericValue == 0) + { + matches!.Sort(); // Format components using the original declaration order. + return string.Join(" | ", matches.Select(i => FormatEnumField(fields[i].Symbol))); + } + } + } + + // Value does not correspond to any combination of defined constants, just cast the numeric value. + return $"({enumType.FullyQualifiedName})({Convert.ToString(value, CultureInfo.InvariantCulture)!})"; + + static string FormatEnumField(IFieldSymbol field) + { + return $"{field.ContainingType.FullyQualifiedName}.{field.Name}"; + } + + static ulong ConvertToUInt64(object value) + { + return value switch + { + byte b => b, + sbyte sb => (ulong)sb, + short s => (ulong)s, + ushort us => us, + char c => c, + int i => (ulong)i, + uint ui => ui, + long l => (ulong)l, + ulong ul => ul, + _ => 0 + }; + } + } + } + + /// + /// Converts a parameter's explicit default value to its C# code representation. + /// + /// The default value object from IParameterSymbol.ExplicitDefaultValue. + /// The C# code string representing the default value, or null if the value is null. + public static string? ToDefaultValueCodeString(object? value) + { + if(value is null) + { + return null; + } + + return value switch + { + string s => $"\"{s.Replace("\\", "\\\\").Replace("\"", "\\\"")}\"", + char c => $"'{(c == '\'' ? "\\'" : c == '\\' ? "\\\\" : c.ToString())}'", + bool b => b ? "true" : "false", + byte or sbyte or short or ushort or int or uint => value.ToString()!, + long l => $"{l}L", + ulong ul => $"{ul}UL", + float f => float.IsNaN(f) ? "float.NaN" : float.IsPositiveInfinity(f) ? "float.PositiveInfinity" : float.IsNegativeInfinity(f) ? "float.NegativeInfinity" : $"{f.ToString(CultureInfo.InvariantCulture)}f", + double d => double.IsNaN(d) ? "double.NaN" : double.IsPositiveInfinity(d) ? "double.PositiveInfinity" : double.IsNegativeInfinity(d) ? "double.NegativeInfinity" : $"{d.ToString(CultureInfo.InvariantCulture)}d", + decimal m => $"{m.ToString(CultureInfo.InvariantCulture)}m", + _ => value.ToString()! + }; + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Wrappers.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Wrappers.cs new file mode 100644 index 0000000..d2dd438 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/RoslynExtensions.Wrappers.cs @@ -0,0 +1,245 @@ +namespace SourceGen.Ioc.SourceGenerator; + +internal static partial class RoslynExtensions +{ + extension(ITypeSymbol typeSymbol) + { + /// + /// Determines whether the type is a built-in/primitive type that cannot be resolved from dependency injection. + /// This includes numeric types, string, bool, char, DateTime, Guid, TimeSpan, Uri, Type, etc. + /// + public bool IsBuiltInType + { + get + { + // Check if it's a special type (primitives, string, object, etc.) + var specialType = typeSymbol.SpecialType; + if(specialType is not SpecialType.None) + { + // These special types are built-in and cannot be resolved from DI + return specialType is + SpecialType.System_Boolean or + SpecialType.System_Char or + SpecialType.System_SByte or + SpecialType.System_Byte or + SpecialType.System_Int16 or + SpecialType.System_UInt16 or + SpecialType.System_Int32 or + SpecialType.System_UInt32 or + SpecialType.System_Int64 or + SpecialType.System_UInt64 or + SpecialType.System_Decimal or + SpecialType.System_Single or + SpecialType.System_Double or + SpecialType.System_String or + SpecialType.System_IntPtr or + SpecialType.System_UIntPtr or + SpecialType.System_Object or + SpecialType.System_DateTime; + } + + // Check for common System types by name + if(typeSymbol.ContainingNamespace?.ToDisplayString() is "System") + { + return typeSymbol.Name is + "Guid" or + "TimeSpan" or + "DateTimeOffset" or + "DateOnly" or + "TimeOnly" or + "Uri" or + "Type" or + "Version" or + "Half" or + "Int128" or + "UInt128"; + } + + return false; + } + } + + /// + /// Determines whether the type is a built-in type, or an array/collection whose element type is built-in. + /// + public bool IsBuiltInTypeOrBuiltInElement + { + get + { + // Check if it's directly a built-in type + if(typeSymbol.IsBuiltInType) + return true; + + // Check if it's an array of built-in type + if(typeSymbol is IArrayTypeSymbol arrayType) + return arrayType.ElementType.IsBuiltInType; + + // Check if it's a generic collection of built-in type + if(typeSymbol is INamedTypeSymbol namedType && namedType.IsGenericType) + { + var typeArgs = namedType.TypeArguments; + if(typeArgs.Length == 1) + { + var elementType = typeArgs[0]; + // Check if the element type is a built-in type + return elementType.IsBuiltInType; + } + } + + return false; + } + } + + /// + /// Attempts to classify the given type symbol as a supported wrapper type and extract its element type. + /// + public bool TryGetWrapperInfo(out WrapperInfo wrapperInfo) + { + // Array: T[] + if(typeSymbol is IArrayTypeSymbol arrayType) + { + wrapperInfo = new(WrapperKind.Array, arrayType.ElementType as INamedTypeSymbol); + return true; + } + + if(typeSymbol is not INamedTypeSymbol namedType) + { + wrapperInfo = default; + return false; + } + + // Func, Func, ... (last type argument is return type) + if(namedType.Arity >= 1 + && IsType(namedType, "System", "Func")) + { + wrapperInfo = new(WrapperKind.Func, namedType.TypeArguments[^1] as INamedTypeSymbol); + return true; + } + + // Arity-2 wrappers where TValue is the element type. + if(namedType.Arity == 2 + && IsTypeInNamespace(namedType, "System.Collections.Generic")) + { + var kind = namedType.Name switch + { + "IDictionary" or "IReadOnlyDictionary" or "Dictionary" => WrapperKind.Dictionary, + "KeyValuePair" => WrapperKind.KeyValuePair, + _ => WrapperKind.None + }; + + if(kind is not WrapperKind.None) + { + wrapperInfo = new(kind, namedType.TypeArguments[1] as INamedTypeSymbol); + return true; + } + } + + // Arity-1 wrappers where T is the element type. + if(namedType.Arity == 1) + { + var elementType = namedType.TypeArguments[0] as INamedTypeSymbol; + + if(namedType.OriginalDefinition.SpecialType is SpecialType.System_Collections_Generic_IEnumerable_T) + { + wrapperInfo = new(WrapperKind.Enumerable, elementType); + return true; + } + + if(IsTypeInNamespace(namedType, "System.Collections.Generic")) + { + var kind = namedType.Name switch + { + "IReadOnlyCollection" => WrapperKind.ReadOnlyCollection, + "ICollection" => WrapperKind.Collection, + "IReadOnlyList" => WrapperKind.ReadOnlyList, + "IList" => WrapperKind.List, + _ => WrapperKind.None + }; + + if(kind is not WrapperKind.None) + { + wrapperInfo = new(kind, elementType); + return true; + } + } + + if(IsType(namedType, "System", "Lazy")) + { + wrapperInfo = new(WrapperKind.Lazy, elementType); + return true; + } + + if(IsTypeInNamespace(namedType, "System.Threading.Tasks")) + { + var kind = namedType.Name switch + { + "Task" => WrapperKind.Task, + "ValueTask" => WrapperKind.ValueTask, + _ => WrapperKind.None + }; + + if(kind is not WrapperKind.None) + { + wrapperInfo = new(kind, elementType); + return true; + } + } + } + + wrapperInfo = default; + return false; + + static bool IsType(INamedTypeSymbol symbol, string @namespace, string name) + => symbol.Name == name && IsTypeInNamespace(symbol, @namespace); + + static bool IsTypeInNamespace(INamedTypeSymbol symbol, string @namespace) + => symbol.ContainingNamespace.ToDisplayString() == @namespace; + } + } + + extension(WrapperKind kind) + { + /// + /// Returns when the wrapper kind is a collection wrapper. + /// + public bool IsCollectionWrapperKind() + => kind is WrapperKind.Enumerable + or WrapperKind.ReadOnlyCollection + or WrapperKind.Collection + or WrapperKind.ReadOnlyList + or WrapperKind.List + or WrapperKind.Array; + } + + /// + /// Returns when the wrapper nesting should be downgraded to unsupported. + /// + public static bool IsUnsupportedWrapperNesting(WrapperKind outerKind, WrapperKind innerKind, bool isAfterCollection) + { + // Task + if(outerKind is WrapperKind.Task && innerKind is not WrapperKind.None) + return true; + + // Wrapper + if(outerKind is not WrapperKind.Task && innerKind is WrapperKind.Task) + return true; + + // ValueTask is only allowed at root; nested ValueTask is unsupported. + if(innerKind is WrapperKind.ValueTask) + return true; + + // A collection wrapper was seen earlier, and we now have a non-collection wrapper + // containing another non-collection wrapper. This is 2+ non-collection layers after + // collection, e.g., IEnumerable>>. + if(isAfterCollection + && !outerKind.IsCollectionWrapperKind() + && !innerKind.IsCollectionWrapperKind() + && innerKind is not WrapperKind.None) + return true; + + return false; + } + + public static bool IsEnumerableType(string nameWithoutGeneric) => + nameWithoutGeneric is "global::System.Collections.Generic.IEnumerable" or "System.Collections.Generic.IEnumerable" or "IEnumerable"; +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/TypeParameterSubstitution.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/TypeParameterSubstitution.cs new file mode 100644 index 0000000..a8ce3c9 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Roslyn/TypeParameterSubstitution.cs @@ -0,0 +1,136 @@ +namespace SourceGen.Ioc.SourceGenerator.Roslyn; + +internal static class TypeParameterSubstitution +{ + public static string SubstituteTypeArguments(string typeName, TypeArgMap typeArgMap) + { + if(typeArgMap.IsEmpty) + { + return typeName; + } + + // Fast path: check if any substitution is needed + var typeNameSpan = typeName.AsSpan(); + bool needsSubstitution = false; + foreach(var (key, _) in typeArgMap) + { + if(ContainsTypeParameter(typeNameSpan, key.AsSpan())) + { + needsSubstitution = true; + break; + } + } + + if(!needsSubstitution) + { + return typeName; + } + + return SubstituteTypeArgumentsCore(typeNameSpan, typeArgMap.AsSpan()); + } + + public static string ReplaceTypeParameter(string typeName, string typeParam, string actualArg) + { + var typeNameSpan = typeName.AsSpan(); + + // Fast path: check if substitution is needed + if(!ContainsTypeParameter(typeNameSpan, typeParam.AsSpan())) + { + return typeName; + } + + // Delegate to core implementation with single-element span + Span singleEntry = [new(typeParam, actualArg)]; + return SubstituteTypeArgumentsCore(typeNameSpan, singleEntry); + } + + private static string SubstituteTypeArgumentsCore( + ReadOnlySpan typeNameSpan, + ReadOnlySpan sortedEntries) + { + var result = new StringBuilder(typeNameSpan.Length + 32); + int i = 0; + + while(i < typeNameSpan.Length) + { + // Check if current position is a valid identifier start + bool isValidStart = i == 0 || !IsIdentifierChar(typeNameSpan[i - 1]); + + if(isValidStart && TryMatchTypeParameter(typeNameSpan, i, sortedEntries, out var match, out int matchLength)) + { + result.Append(match); + i += matchLength; + } + else + { + result.Append(typeNameSpan[i]); + i++; + } + } + + return result.ToString(); + } + + private static bool TryMatchTypeParameter( + ReadOnlySpan typeNameSpan, + int position, + ReadOnlySpan sortedEntries, + [NotNullWhen(true)] out string? replacement, + out int matchLength) + { + foreach(var (key, value) in sortedEntries) + { + var typeParamSpan = key.AsSpan(); + int paramLength = typeParamSpan.Length; + + if(position + paramLength <= typeNameSpan.Length && + typeNameSpan.Slice(position, paramLength).SequenceEqual(typeParamSpan)) + { + // Check if it's a whole word (ends at identifier boundary) + bool isEnd = position + paramLength == typeNameSpan.Length + || !IsIdentifierChar(typeNameSpan[position + paramLength]); + + if(isEnd) + { + replacement = value; + matchLength = paramLength; + return true; + } + } + } + + replacement = null; + matchLength = 0; + return false; + } + + private static bool ContainsTypeParameter(ReadOnlySpan typeName, ReadOnlySpan typeParam) + { + int index = 0; + while(index <= typeName.Length - typeParam.Length) + { + int pos = typeName[index..].IndexOf(typeParam, StringComparison.Ordinal); + if(pos < 0) + { + return false; + } + + int absolutePos = index + pos; + bool isStart = absolutePos == 0 + || !IsIdentifierChar(typeName[absolutePos - 1]); + bool isEnd = absolutePos + typeParam.Length == typeName.Length + || !IsIdentifierChar(typeName[absolutePos + typeParam.Length]); + + if(isStart && isEnd) + { + return true; + } + + index = absolutePos + 1; + } + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsIdentifierChar(char c) => char.IsLetterOrDigit(c) || c == '_'; +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/RoslynExtensions.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/RoslynExtensions.cs deleted file mode 100644 index c4e3424..0000000 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/RoslynExtensions.cs +++ /dev/null @@ -1,1049 +0,0 @@ -ο»Ώusing System.Globalization; - -namespace SourceGen.Ioc.SourceGenerator; - -/// -/// Extension methods for Roslyn symbol manipulation. -/// -internal static class RoslynExtensions -{ - /// The symbol - extension(ISymbol symbol) - { - /// - /// Builds the fully qualified access path for a symbol, including namespace and containing types.
- /// For example, for a field Key inside NestClassImpl inside TestNestClass in namespace MyApp.Services, - /// returns global::MyApp.Services.TestNestClass.NestClassImpl.Key. - ///
- /// The fully qualified access path for the symbol. - public string FullAccessPath - { - get - { - // Build path from the symbol up to its containing types (collect in reverse order, then reverse) - List pathParts = [symbol.Name]; - - var containingType = symbol.ContainingType; - while(containingType is not null) - { - pathParts.Add(containingType.Name); - containingType = containingType.ContainingType; - } - - // Reverse to get correct order (outermost type first) - pathParts.Reverse(); - - // Add namespace prefix with global:: - var containingNamespace = symbol.ContainingType?.ContainingNamespace ?? symbol.ContainingNamespace; - if(containingNamespace is not null && !containingNamespace.IsGlobalNamespace) - { - var namespacePath = containingNamespace.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - return $"{namespacePath}.{string.Join(".", pathParts)}"; - } - - // For global namespace, just prepend global:: - return $"global::{string.Join(".", pathParts)}"; - } - } - } - - extension(ITypeSymbol typeSymbol) - { - /// - /// Gets the fully qualified name of a type symbol. - /// - public string FullyQualifiedName => typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - - public bool IsNullable => !typeSymbol.IsValueType || typeSymbol.OriginalDefinition.SpecialType is SpecialType.System_Nullable_T; - - /// - /// Determines whether the type is a built-in/primitive type that cannot be resolved from dependency injection. - /// This includes numeric types, string, bool, char, DateTime, Guid, TimeSpan, Uri, Type, etc. - /// - public bool IsBuiltInType - { - get - { - // Check if it's a special type (primitives, string, object, etc.) - var specialType = typeSymbol.SpecialType; - if(specialType is not SpecialType.None) - { - // These special types are built-in and cannot be resolved from DI - return specialType is - SpecialType.System_Boolean or - SpecialType.System_Char or - SpecialType.System_SByte or - SpecialType.System_Byte or - SpecialType.System_Int16 or - SpecialType.System_UInt16 or - SpecialType.System_Int32 or - SpecialType.System_UInt32 or - SpecialType.System_Int64 or - SpecialType.System_UInt64 or - SpecialType.System_Decimal or - SpecialType.System_Single or - SpecialType.System_Double or - SpecialType.System_String or - SpecialType.System_IntPtr or - SpecialType.System_UIntPtr or - SpecialType.System_Object or - SpecialType.System_DateTime; - } - - // Check for common System types by name - if(typeSymbol.ContainingNamespace?.ToDisplayString() is "System") - { - return typeSymbol.Name is - "Guid" or - "TimeSpan" or - "DateTimeOffset" or - "DateOnly" or - "TimeOnly" or - "Uri" or - "Type" or - "Version" or - "Half" or - "Int128" or - "UInt128"; - } - - return false; - } - } - - /// - /// Determines whether the type is a built-in type, or an array/collection whose element type is built-in. - /// - public bool IsBuiltInTypeOrBuiltInElement - { - get - { - // Check if it's directly a built-in type - if(typeSymbol.IsBuiltInType) - return true; - - // Check if it's an array of built-in type - if(typeSymbol is IArrayTypeSymbol arrayType) - return arrayType.ElementType.IsBuiltInType; - - // Check if it's a generic collection of built-in type - if(typeSymbol is INamedTypeSymbol namedType && namedType.IsGenericType) - { - var typeArgs = namedType.TypeArguments; - if(typeArgs.Length == 1) - { - var elementType = typeArgs[0]; - // Check if the element type is a built-in type - return elementType.IsBuiltInType; - } - } - - return false; - } - } - - public bool ContainsGenericParameters - { - get - { - if(typeSymbol.TypeKind is TypeKind.TypeParameter or TypeKind.Error) - { - return true; - } - - if(typeSymbol is INamedTypeSymbol namedTypeSymbol) - { - if(namedTypeSymbol.IsUnboundGenericType) - { - return true; - } - - for(; namedTypeSymbol != null; namedTypeSymbol = namedTypeSymbol.ContainingType) - { - if(namedTypeSymbol.TypeArguments.Any(arg => arg.ContainsGenericParameters)) - { - return true; - } - } - } - - return false; - } - } - - public INamedTypeSymbol? GetCompatibleGenericBaseType([NotNullWhen(true)] INamedTypeSymbol? genericType) - { - if(genericType is null) - { - return null; - } - - Debug.Assert(genericType.IsGenericTypeDefinition); - - if(genericType.TypeKind is TypeKind.Interface) - { - foreach(INamedTypeSymbol interfaceType in typeSymbol.AllInterfaces) - { - if(IsMatchingGenericType(interfaceType, genericType)) - { - return interfaceType; - } - } - } - - for(INamedTypeSymbol? current = typeSymbol as INamedTypeSymbol; current != null; current = current.BaseType) - { - if(IsMatchingGenericType(current, genericType)) - { - return current; - } - } - - return null; - - static bool IsMatchingGenericType(INamedTypeSymbol candidate, INamedTypeSymbol baseType) - { - return candidate.IsGenericType && SymbolEqualityComparer.Default.Equals(candidate.ConstructedFrom, baseType); - } - } - } - - extension(INamedTypeSymbol typeSymbol) - { - public bool IsGenericTypeDefinition => typeSymbol is { IsGenericType: true, IsDefinition: true }; - - /// - /// Gets the type parameters source for this type symbol. - /// For unbound generic types and constructed generic types, returns parameters from the original definition. - /// This allows matching type parameter names (TRequest, TResponse) with type arguments (TestRequest, List<string>). - /// - public ImmutableArray TypeParametersSource => - typeSymbol.IsGenericType - ? typeSymbol.OriginalDefinition?.TypeParameters ?? typeSymbol.TypeParameters - : typeSymbol.TypeParameters; - - /// - /// Determines whether the type is a nested open generic type. - /// A nested open generic is a generic type where any type argument itself contains generic parameters. - /// For example: IGeneric<IGeneric2<T>> is a nested open generic. - /// But IGeneric<T> or IGeneric<int> are not. - /// - public bool IsNestedOpenGeneric - { - get - { - if(!typeSymbol.IsGenericType) - { - return false; - } - - // For unbound generic types (e.g., IRepository<>), TypeArguments contains error types - // which should not be considered as nested open generics - if(typeSymbol.IsUnboundGenericType) - { - return false; - } - - // Check if any type argument contains generic parameters - foreach(var typeArg in typeSymbol.TypeArguments) - { - // If the type argument is not a simple type parameter (T, T1, etc.) - // but contains generic parameters, it's a nested open generic - if(typeArg.TypeKind != TypeKind.TypeParameter && typeArg.ContainsGenericParameters) - { - return true; - } - } - - return false; - } - } - - public IMethodSymbol? PrimaryConstructor - { - get - { - foreach(var ctor in typeSymbol.Constructors) - { - if(ctor.IsImplicitlyDeclared) - continue; - - var syntaxRef = ctor.DeclaringSyntaxReferences.FirstOrDefault(); - if(syntaxRef?.GetSyntax() is TypeDeclarationSyntax) - return ctor; - } - - return null; - } - } - - public IMethodSymbol? PrimaryOrMostParametersConstructor - { - get - { - IMethodSymbol? bestCtor = null; - int maxParameters = -1; - foreach(var ctor in typeSymbol.Constructors) - { - if(ctor.IsImplicitlyDeclared) - continue; - - var syntaxRef = ctor.DeclaringSyntaxReferences.FirstOrDefault(); - // Primary constructor - if(syntaxRef?.GetSyntax() is TypeDeclarationSyntax) - return ctor; - - if(ctor.IsStatic) - continue; - - if(ctor.DeclaredAccessibility is not (Accessibility.Public or Accessibility.Internal)) - continue; - - // Find constructor with most parameters - if(ctor.Parameters.Length > maxParameters) - { - maxParameters = ctor.Parameters.Length; - bestCtor = ctor; - } - } - return bestCtor; - } - } - } - - extension(AttributeData attributeData) - { - /// - /// Gets a named argument value from an attribute data. - /// - public T? GetNamedArgument(string name, T? defaultValue = default) - { - foreach(var namedArg in attributeData.NamedArguments) - { - if(namedArg.Key == name) - { - if(namedArg.Value.IsNull) - { - return defaultValue; - } - - return (T?)namedArg.Value.Value; - } - } - - return defaultValue; - } - - /// - /// Tries to get a named argument value from an attribute data.
- /// If the argument is not found, returns HasArg = false. - ///
- public (bool HasArg, T? Value) TryGetNamedArgument(string name, T? defaultValue = default) - { - foreach(var namedArg in attributeData.NamedArguments) - { - if(namedArg.Key == name) - { - if(namedArg.Value.IsNull) - { - return (true, defaultValue); - } - - return (true, (T?)namedArg.Value.Value); - } - } - - return (false, defaultValue); - } - - /// - /// Checks if a named argument was explicitly set in an attribute. - /// - public bool HasNamedArgument(string name) - { - foreach(var namedArg in attributeData.NamedArguments) - { - if(namedArg.Key == name) - { - return true; - } - } - - return false; - } - - /// - /// Tries to get the original syntax for a named argument, especially for expressions. - /// When a is provided, resolves the full access path of the referenced symbol. - /// - /// The name of the argument to find. - /// Optional semantic model to resolve full access paths for nameof() expressions. - /// The resolved symbol path if it's a expression; otherwise, null. - public string? TryGetNameof(string argumentName, SemanticModel? semanticModel = null) - { - var syntaxReference = attributeData.ApplicationSyntaxReference; - if(syntaxReference is null) - return null; - - var syntax = syntaxReference.GetSyntax(); - if(syntax is not AttributeSyntax attributeSyntax) - return null; - - var argumentList = attributeSyntax.ArgumentList; - if(argumentList is null) - return null; - - foreach(var argument in argumentList.Arguments) - { - // Check if this is a named argument with the correct name - if(argument.NameEquals?.Name.Identifier.Text == argumentName) - { - // Check if the expression is a nameof() invocation - if(argument.Expression is InvocationExpressionSyntax invocation && - invocation.Expression is IdentifierNameSyntax identifierName && - identifierName.Identifier.Text == "nameof") - { - // Extract the argument inside nameof() and return just that expression - if(invocation.ArgumentList.Arguments.Count == 1) - { - var nameofArgument = invocation.ArgumentList.Arguments[0].Expression; - - // If semantic model is provided, try to resolve the full access path - if(semanticModel is not null) - { - var resolvedPath = ResolveNameofExpression(nameofArgument, semanticModel); - if(resolvedPath is not null) - return resolvedPath; - } - - return nameofArgument.ToFullString().Trim(); - } - } - } - } - - return null; - } - - /// - /// Tries to extract the expression from a constructor argument of an attribute. - /// - /// The index of the constructor argument to check. - /// Optional semantic model to resolve full access paths for nameof() expressions. - /// The resolved symbol path if it's a expression; otherwise, null. - public string? TryGetNameofFromConstructorArg(int argumentIndex, SemanticModel? semanticModel = null) - { - var syntaxReference = attributeData.ApplicationSyntaxReference; - if(syntaxReference is null) - return null; - - var syntax = syntaxReference.GetSyntax(); - if(syntax is not AttributeSyntax attributeSyntax) - return null; - - var argumentList = attributeSyntax.ArgumentList; - if(argumentList is null || argumentList.Arguments.Count <= argumentIndex) - return null; - - var argument = argumentList.Arguments[argumentIndex]; - - // Skip named arguments (they don't count as constructor arguments) - if(argument.NameEquals is not null) - return null; - - // Check if the expression is a nameof() invocation - if(argument.Expression is InvocationExpressionSyntax invocation && - invocation.Expression is IdentifierNameSyntax identifierName && - identifierName.Identifier.Text == "nameof") - { - // Extract the argument inside nameof() and return just that expression - if(invocation.ArgumentList.Arguments.Count == 1) - { - var nameofArgument = invocation.ArgumentList.Arguments[0].Expression; - - // If semantic model is provided, try to resolve the full access path - if(semanticModel is not null) - { - var resolvedPath = ResolveNameofExpression(nameofArgument, semanticModel); - if(resolvedPath is not null) - return resolvedPath; - } - - return nameofArgument.ToFullString().Trim(); - } - } - - return null; - } - } - - extension(TypedConstant constant) - { - public string GetPrimitiveConstantString() => FormatPrimitiveConstant(constant.Type, constant.Value); - } - - public static string FormatPrimitiveConstant(ITypeSymbol? type, object? value) - { - if(type?.OriginalDefinition.SpecialType is SpecialType.System_Nullable_T) - { - var elementType = ((INamedTypeSymbol)type).TypeArguments[0]; - return value is null ? "null" : FormatPrimitiveConstant(elementType, value); - } - - if(type?.TypeKind is TypeKind.Enum) - { - return FormatEnumLiteral((INamedTypeSymbol)type, value!); - } - - return value switch - { - null => type?.IsNullable is null or true ? "null!" : "default", - false => "false", - true => "true", - - string s => SymbolDisplay.FormatLiteral(s, quote: true), - char c => SymbolDisplay.FormatLiteral(c, quote: true), - - double.NaN => "double.NaN", - double.NegativeInfinity => "double.NegativeInfinity", - double.PositiveInfinity => "double.PositiveInfinity", - double d => $"{d.ToString("G17", CultureInfo.InvariantCulture)}d", - - float.NaN => "float.NaN", - float.NegativeInfinity => "float.NegativeInfinity", - float.PositiveInfinity => "float.PositiveInfinity", - float f => $"{f.ToString("G9", CultureInfo.InvariantCulture)}f", - - decimal d => $"{d.ToString(CultureInfo.InvariantCulture)}m", - - // Must be one of the other numeric types or an enum - object num => Convert.ToString(num, CultureInfo.InvariantCulture), - }; - - static string FormatEnumLiteral(INamedTypeSymbol enumType, object value) - { - Debug.Assert(enumType.TypeKind is TypeKind.Enum); - - foreach(ISymbol member in enumType.GetMembers()) - { - if(member is IFieldSymbol { IsConst: true, ConstantValue: { } constantValue } field) - { - if(Equals(constantValue, value)) - { - return FormatEnumField(field); - } - } - } - - bool isFlagsEnum = enumType.GetAttributes().Any(attr => - attr.AttributeClass?.Name == "FlagsAttribute" && - attr.AttributeClass.ContainingNamespace.ToDisplayString() == "System"); - - if(isFlagsEnum) - { - // Convert the value to ulong for bitwise operations - ulong numericValue = ConvertToUInt64(value); - var fields = enumType.GetMembers().OfType() - .Select((f, i) => (Index: i, Symbol: f, NumericValue: ConvertToUInt64(f.ConstantValue!))) - .ToArray(); - - // Check for any zero numeric values. - if(numericValue == 0) - { - foreach(var field in fields) - { - if(field.NumericValue == 0) - { - return FormatEnumField(field.Symbol); - } - } - } - else - { - List? matches = null; - foreach(var field in fields.OrderByDescending(f => f.NumericValue)) - { - // Greedy match of flag values from highest to lowest numeric value. - if(field.NumericValue != 0 && (numericValue & field.NumericValue) == field.NumericValue) - { - (matches ??= []).Add(field.Index); - numericValue &= ~field.NumericValue; - if(numericValue == 0) - { - break; // All bits accounted for - } - } - } - - if(numericValue == 0) - { - matches!.Sort(); // Format components using the original declaration order. - return string.Join(" | ", matches.Select(i => FormatEnumField(fields[i].Symbol))); - } - } - } - - // Value does not correspond to any combination of defined constants, just cast the numeric value. - return $"({enumType.FullyQualifiedName})({Convert.ToString(value, CultureInfo.InvariantCulture)!})"; - - static string FormatEnumField(IFieldSymbol field) - { - return $"{field.ContainingType.FullyQualifiedName}.{field.Name}"; - } - - static ulong ConvertToUInt64(object value) - { - return value switch - { - byte b => b, - sbyte sb => (ulong)sb, - short s => (ulong)s, - ushort us => us, - char c => c, - int i => (ulong)i, - uint ui => ui, - long l => (ulong)l, - ulong ul => ul, - _ => 0 - }; - } - } - } - - /// - /// Converts a string to a safe C# namespace name. - /// Replaces invalid characters (like '-') with underscores, preserving '.' as namespace separator. - /// - /// The namespace name to convert. - /// A safe C# namespace name. - public static string GetSafeNamespace(string name) - { - if(string.IsNullOrWhiteSpace(name)) - return "Generated"; - - ReadOnlySpan nameSpan = name.AsSpan(); - - // Check if first char is digit (needs underscore prefix) - var needsPrefix = nameSpan.Length > 0 && char.IsDigit(nameSpan[0]); - var maxLength = nameSpan.Length + (needsPrefix ? 1 : 0); - - // Use stackalloc for small strings (up to 256 chars), otherwise use array pool - const int StackAllocThreshold = 256; - Span buffer = maxLength <= StackAllocThreshold - ? stackalloc char[StackAllocThreshold] - : new char[maxLength]; - - var writeIndex = 0; - - if(needsPrefix) - { - buffer[writeIndex++] = '_'; - } - - for(var i = 0; i < nameSpan.Length; i++) - { - var ch = nameSpan[i]; - // Allow letters, digits, underscore, and dot (namespace separator) - buffer[writeIndex++] = char.IsLetterOrDigit(ch) || ch is '_' or '.' ? ch : '_'; - } - - return buffer[..writeIndex].ToString(); - } - - /// - /// Converts a string to a safe C# identifier. - /// Removes all global:: prefixes and replaces non-identifier characters with underscores. - /// Uses stack allocation for small strings to reduce heap allocations. - /// - /// The name to convert. - /// The fallback value if name is null or whitespace. Default is "Generated". - /// A safe C# identifier. - public static string GetSafeIdentifier(string name, string fallback = "Generated") - { - if(string.IsNullOrWhiteSpace(name)) - return fallback; - - // Remove all global:: prefixes (not just the first one, e.g., in generic types) - var processedName = name.Replace("global::", ""); - - ReadOnlySpan nameSpan = processedName.AsSpan(); - - // Check if first char is digit (needs underscore prefix) - var needsPrefix = nameSpan.Length > 0 && char.IsDigit(nameSpan[0]); - var maxLength = nameSpan.Length + (needsPrefix ? 1 : 0); - - // Use stackalloc for small strings (up to 256 chars), otherwise use array pool - const int StackAllocThreshold = 256; - Span buffer = maxLength <= StackAllocThreshold - ? stackalloc char[StackAllocThreshold] - : new char[maxLength]; - - var writeIndex = 0; - - if(needsPrefix) - { - buffer[writeIndex++] = '_'; - } - - for(var i = 0; i < nameSpan.Length; i++) - { - var ch = nameSpan[i]; - buffer[writeIndex++] = char.IsLetterOrDigit(ch) || ch == '_' ? ch : '_'; - } - - return buffer[..writeIndex].ToString(); - } - - /// - /// Checks if sourceType is assignable to targetType. - /// Returns true if a value of sourceType can be assigned to a variable of targetType. - /// - /// The target type (e.g., parameter type). - /// The source type (e.g., value type). - /// True if sourceType is assignable to targetType. - public static bool IsAssignable(ITypeSymbol targetType, ITypeSymbol sourceType) - { - // Exact match - if(SymbolEqualityComparer.Default.Equals(targetType, sourceType)) - return true; - - // Handle nullable types - if target is nullable and source type is the underlying type - if(targetType.OriginalDefinition.SpecialType is SpecialType.System_Nullable_T - && targetType is INamedTypeSymbol nullableTarget) - { - var underlyingType = nullableTarget.TypeArguments.FirstOrDefault(); - if(underlyingType is not null && SymbolEqualityComparer.Default.Equals(underlyingType, sourceType)) - return true; - } - - // Handle object type - any type is assignable to object - if(targetType.SpecialType is SpecialType.System_Object) - return true; - - // Handle inheritance - target type should be a base type or interface of source type - if(sourceType.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, targetType))) - return true; - - var currentBase = sourceType.BaseType; - while(currentBase is not null) - { - if(SymbolEqualityComparer.Default.Equals(currentBase, targetType)) - return true; - currentBase = currentBase.BaseType; - } - - return false; - } - - public static bool IsEnumerableType(string nameWithoutGeneric) => - nameWithoutGeneric is "global::System.Collections.Generic.IEnumerable" or "System.Collections.Generic.IEnumerable" or "IEnumerable"; - - /// - /// Read-only collection type name (without generic part): IReadOnlyCollection<T>. - /// - private static readonly HashSet s_readOnlyCollectionTypes = new(StringComparer.Ordinal) - { - "global::System.Collections.Generic.IReadOnlyCollection", - "System.Collections.Generic.IReadOnlyCollection", - "IReadOnlyCollection" - }; - public static bool IsReadOnlyCollectionType(string nameWithoutGeneric) => - s_readOnlyCollectionTypes.Contains(nameWithoutGeneric); - - /// - /// Read-only list type name (without generic part): IReadOnlyList<T>. - /// - private static readonly HashSet s_readOnlyListTypes = new(StringComparer.Ordinal) - { - "global::System.Collections.Generic.IReadOnlyList", - "System.Collections.Generic.IReadOnlyList", - "IReadOnlyList" - }; - public static bool IsReadOnlyListType(string nameWithoutGeneric) => - s_readOnlyListTypes.Contains(nameWithoutGeneric); - - /// - /// Mutable collection interface type name (without generic part): ICollection<T>. - /// - private static readonly HashSet s_collectionTypes = new(StringComparer.Ordinal) - { - "global::System.Collections.Generic.ICollection", - "System.Collections.Generic.ICollection", - "ICollection" - }; - public static bool IsCollectionType(string nameWithoutGeneric) => - s_collectionTypes.Contains(nameWithoutGeneric); - - /// - /// Mutable list interface type name (without generic part): IList<T>. - /// - private static readonly HashSet s_listTypes = new(StringComparer.Ordinal) - { - "global::System.Collections.Generic.IList", - "System.Collections.Generic.IList", - "IList" - }; - public static bool IsListType(string nameWithoutGeneric) => - s_listTypes.Contains(nameWithoutGeneric); - - /// - /// Resolves the full access path of a symbol referenced in a nameof() expression. - /// For example, resolves nameof(Key) to global::Namespace.OuterClass.InnerClass.Key - /// when Key is a member of InnerClass inside OuterClass. - /// - /// The expression inside nameof(). - /// The semantic model to use for symbol resolution. - /// The full access path if successfully resolved; otherwise, null. - public static string? ResolveNameofExpression(ExpressionSyntax expression, SemanticModel semanticModel) - { - var symbolInfo = semanticModel.GetSymbolInfo(expression); - var symbol = symbolInfo.Symbol ?? symbolInfo.CandidateSymbols.FirstOrDefault(); - - if(symbol is null) - return null; - - // If the expression already contains a member access (e.g., KeyHolder.Key), - // we need to resolve it to ensure we get the fully qualified path - return symbol.FullAccessPath; - } - - /// - /// Resolves the method symbol from a nameof() or string expression in an attribute. - /// - /// The expression inside nameof() or a string literal. - /// The semantic model to use for symbol resolution. - /// The method symbol if found; otherwise, null. - public static IMethodSymbol? ResolveMethodSymbol(ExpressionSyntax expression, SemanticModel semanticModel) - { - var symbolInfo = semanticModel.GetSymbolInfo(expression); - var symbol = symbolInfo.Symbol ?? symbolInfo.CandidateSymbols.FirstOrDefault(); - - return symbol as IMethodSymbol; - } - - /// - /// Returns when the method returns the non-generic - /// type (arity 0). - /// - internal static bool IsNonGenericTaskReturnType(IMethodSymbol method) - => method.ReturnType is INamedTypeSymbol { Arity: 0, Name: "Task" } named - && named.ContainingNamespace.ToDisplayString() == "System.Threading.Tasks"; - - extension(IEnumerable source) - { - public IEnumerable<(int Index, T Item)> Index() - { - int index = 0; - foreach(var item in source) - { - yield return (index, item); - checked { index++; } - } - } - } - - extension(IReadOnlyList source) - { - public IEnumerable<(int Index, T Item)> Index() - { - for(int i = 0; i < source.Count; i++) - { - yield return (i, source[i]); - } - } - } - - #region Type Parameter Substitution - - /// - /// Substitutes multiple type parameters in a type name with actual type arguments. - /// Uses Span-based processing to minimize string allocations. - /// - /// The type name containing type parameters to substitute. - /// A map of type parameter names to their actual type arguments. - /// The type name with all type parameters substituted. - public static string SubstituteTypeArguments(string typeName, TypeArgMap typeArgMap) - { - if(typeArgMap.IsEmpty) - { - return typeName; - } - - // Fast path: check if any substitution is needed - var typeNameSpan = typeName.AsSpan(); - bool needsSubstitution = false; - foreach(var (key, _) in typeArgMap) - { - if(ContainsTypeParameter(typeNameSpan, key.AsSpan())) - { - needsSubstitution = true; - break; - } - } - - if(!needsSubstitution) - { - return typeName; - } - - return SubstituteTypeArgumentsCore(typeNameSpan, typeArgMap.AsSpan()); - } - - /// - /// Replaces a single type parameter with an actual type argument in a type name. - /// - /// The type name containing the type parameter. - /// The type parameter name to replace (e.g., "T"). - /// The actual type argument to substitute (e.g., "string"). - /// The type name with the type parameter replaced. - public static string ReplaceTypeParameter(string typeName, string typeParam, string actualArg) - { - var typeNameSpan = typeName.AsSpan(); - - // Fast path: check if substitution is needed - if(!ContainsTypeParameter(typeNameSpan, typeParam.AsSpan())) - { - return typeName; - } - - // Delegate to core implementation with single-element span - Span singleEntry = [new(typeParam, actualArg)]; - return SubstituteTypeArgumentsCore(typeNameSpan, singleEntry); - } - - /// - /// Core implementation for type parameter substitution. - /// Performs all substitutions in a single pass using StringBuilder. - /// - /// The type name span to process. - /// Entries sorted by key length descending for correct matching priority. - /// The type name with all type parameters substituted. - private static string SubstituteTypeArgumentsCore( - ReadOnlySpan typeNameSpan, - ReadOnlySpan sortedEntries) - { - var result = new StringBuilder(typeNameSpan.Length + 32); - int i = 0; - - while(i < typeNameSpan.Length) - { - // Check if current position is a valid identifier start - bool isValidStart = i == 0 || !IsIdentifierChar(typeNameSpan[i - 1]); - - if(isValidStart && TryMatchTypeParameter(typeNameSpan, i, sortedEntries, out var match, out int matchLength)) - { - result.Append(match); - i += matchLength; - } - else - { - result.Append(typeNameSpan[i]); - i++; - } - } - - return result.ToString(); - } - - /// - /// Tries to match a type parameter at the given position. - /// - /// True if a match was found, with the replacement value and match length. - private static bool TryMatchTypeParameter( - ReadOnlySpan typeNameSpan, - int position, - ReadOnlySpan sortedEntries, - [NotNullWhen(true)] out string? replacement, - out int matchLength) - { - foreach(var (key, value) in sortedEntries) - { - var typeParamSpan = key.AsSpan(); - int paramLength = typeParamSpan.Length; - - if(position + paramLength <= typeNameSpan.Length && - typeNameSpan.Slice(position, paramLength).SequenceEqual(typeParamSpan)) - { - // Check if it's a whole word (ends at identifier boundary) - bool isEnd = position + paramLength == typeNameSpan.Length - || !IsIdentifierChar(typeNameSpan[position + paramLength]); - - if(isEnd) - { - replacement = value; - matchLength = paramLength; - return true; - } - } - } - - replacement = null; - matchLength = 0; - return false; - } - - /// - /// Checks if the type name contains the type parameter as a whole word. - /// - private static bool ContainsTypeParameter(ReadOnlySpan typeName, ReadOnlySpan typeParam) - { - int index = 0; - while(index <= typeName.Length - typeParam.Length) - { - int pos = typeName[index..].IndexOf(typeParam, StringComparison.Ordinal); - if(pos < 0) - { - return false; - } - - int absolutePos = index + pos; - bool isStart = absolutePos == 0 - || !IsIdentifierChar(typeName[absolutePos - 1]); - bool isEnd = absolutePos + typeParam.Length == typeName.Length - || !IsIdentifierChar(typeName[absolutePos + typeParam.Length]); - - if(isStart && isEnd) - { - return true; - } - - index = absolutePos + 1; - } - return false; - } - - /// - /// Checks if a character can be part of an identifier. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static bool IsIdentifierChar(char c) => char.IsLetterOrDigit(c) || c == '_'; - - #endregion - - /// - /// Converts a parameter's explicit default value to its C# code representation. - /// - /// The default value object from IParameterSymbol.ExplicitDefaultValue. - /// The C# code string representing the default value, or null if the value is null. - public static string? ToDefaultValueCodeString(object? value) - { - if(value is null) - { - return null; - } - - return value switch - { - string s => $"\"{s.Replace("\\", "\\\\").Replace("\"", "\\\"")}\"", - char c => $"'{(c == '\'' ? "\\'" : c == '\\' ? "\\\\" : c.ToString())}'", - bool b => b ? "true" : "false", - byte or sbyte or short or ushort or int or uint => value.ToString()!, - long l => $"{l}L", - ulong ul => $"{ul}UL", - float f => float.IsNaN(f) ? "float.NaN" : float.IsPositiveInfinity(f) ? "float.PositiveInfinity" : float.IsNegativeInfinity(f) ? "float.NegativeInfinity" : $"{f.ToString(CultureInfo.InvariantCulture)}f", - double d => double.IsNaN(d) ? "double.NaN" : double.IsPositiveInfinity(d) ? "double.PositiveInfinity" : double.IsNegativeInfinity(d) ? "double.NegativeInfinity" : $"{d.ToString(CultureInfo.InvariantCulture)}d", - decimal m => $"{m.ToString(CultureInfo.InvariantCulture)}m", - _ => value.ToString()! - }; - } -} diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.AspNetCore.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.AspNetCore.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.AspNetCore.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.AspNetCore.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Basic.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Basic.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Basic.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Basic.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Collections.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Collections.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Collections.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Collections.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Decorators.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Decorators.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Decorators.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Decorators.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Factory.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Factory.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Factory.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Factory.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Generics.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Generics.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Generics.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Generics.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.ImportModule.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.ImportModule.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.ImportModule.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.ImportModule.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Injection.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Injection.spec.md similarity index 91% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Injection.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Injection.spec.md index e9991bd..bb1d851 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Injection.spec.md +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Injection.spec.md @@ -84,7 +84,10 @@ When a registration contains one or more async inject methods, the container MUS |`ThreadSafeStrategy.None`|Allowed. The container MAY assign the task field directly without synchronization.| |`ThreadSafeStrategy.SemaphoreSlim`|Allowed. Singleton/scoped async-init services MUST use `WaitAsync()` / `Release()` around first initialization.| |`ThreadSafeStrategy.Lock`, `ThreadSafeStrategy.SpinLock`, or `ThreadSafeStrategy.CompareExchange`|Async-incompatible for async-init services and MUST NOT be used for that resolver path.| -|`EagerResolveOptions` includes singleton and/or scoped services|Async-init services MUST be excluded from eager resolution. The container constructor/scope constructor MUST NOT pre-start those tasks.| +|`EagerResolveOptions` and singleton/scoped services|`EagerResolveOptions` controls eager resolution of **both** synchronous singleton/scoped services **and** async-init singleton/scoped services.| +|Singleton async-init registration β€” eager init|When `EagerResolveOptions` includes `Singleton`, the container constructor MUST emit `_ = GetXxxAsync();` (fire-and-forget) to pre-start the task. Otherwise, no fire-and-forget call is emitted.| +|Scoped async-init registration β€” eager init|When `EagerResolveOptions` includes `Scoped`, the scope constructor MUST emit `_ = GetXxxAsync();` (fire-and-forget) to pre-start the task. Otherwise, no fire-and-forget call is emitted.| +|Transient async-init registration β€” eager init|Transient async-init services are **never** eagerly started. No fire-and-forget call is emitted.| |Collection wrappers (`IEnumerable`, `IReadOnlyCollection`, `IReadOnlyList`, `IList`, `T[]`)|Async-init registrations MUST be excluded from collection resolvers. `IEnumerable>` is not supported.| ```mermaid @@ -133,7 +136,7 @@ public sealed class FooBar : IFoo, IBar } } -[IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.SemaphoreSlim)] +[IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.SemaphoreSlim, EagerResolveOptions = EagerResolveOptions.Singleton)] public partial class AppContainer { public partial Task GetFooAsync(); @@ -151,8 +154,8 @@ partial class AppContainer { _fallbackProvider = fallbackProvider; - // Async-init singleton is excluded from eager resolution. - // _fooBar = GetFooBarAsync(); // MUST NOT be emitted + // Pre-start (fire-and-forget) the async-init singleton because `EagerResolveOptions` includes `Singleton`. + _ = GetFooBarAsync(); } private async global::System.Threading.Tasks.Task GetFooBarAsync() diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.KeyedServices.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.KeyedServices.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.KeyedServices.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.KeyedServices.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Lifetime.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Lifetime.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Lifetime.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Lifetime.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Options.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Options.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Options.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Options.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.PartialAccessors.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.PartialAccessors.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.PartialAccessors.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.PartialAccessors.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Performance.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Performance.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.Performance.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.Performance.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.ThreadSafety.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.ThreadSafety.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Container.ThreadSafety.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Container.ThreadSafety.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Basic.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Basic.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Basic.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Basic.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Decorators.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Decorators.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Decorators.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Decorators.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Factory.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Factory.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Factory.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Factory.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Generics.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Generics.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Generics.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Generics.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.ImportModule.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.ImportModule.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.ImportModule.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.ImportModule.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Injection.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Injection.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Injection.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Injection.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.KeyValuePair.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.KeyValuePair.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.KeyValuePair.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.KeyValuePair.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.MSBuild.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.MSBuild.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.MSBuild.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.MSBuild.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.ServiceProviderInvocation.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.ServiceProviderInvocation.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.ServiceProviderInvocation.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.ServiceProviderInvocation.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Tags.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Tags.spec.md similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/Spec/Register.Tags.spec.md rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/Register.Tags.spec.md diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/SPEC.spec.md b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/SPEC.spec.md new file mode 100644 index 0000000..b224993 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Spec/SPEC.spec.md @@ -0,0 +1,212 @@ +# IocSourceGenerator Specification + +`IocSourceGenerator` is an `IIncrementalGenerator` that emits two kinds of output from `Ioc*` attributes (and `IServiceProvider.GetService` invocations) in user code: + +- **Register output** (`{assemblyName}.ServiceRegistration.g.cs`) β€” extension methods that register services into `IServiceCollection`. +- **Container output** (`{containerClassName}.Container.g.cs`) β€” standalone, `IServiceCollection`-free DI container partial classes. + +Per-feature documentation lives under `Spec/Register.*.spec.md` and `Spec/Container.*.spec.md` (see [Spec Index](#spec-index)). This document is the navigation map and architecture overview. + +--- + +## Pipeline at a glance + +```text +Stage 1 ForAttributeWithMetadataName + CreateSyntaxProvider + β”œβ”€β”€ [IocRegister] / [IocRegister] ─ Transforms/TransformRegister.cs + β”œβ”€β”€ [IocRegisterFor] / [IocRegisterFor] ─ Transforms/TransformRegister.cs + β”œβ”€β”€ [IocRegisterDefaults] / [IocRegisterDefaults]─ Transforms/TransformDefaultSettings.cs + β”œβ”€β”€ [IocImportModule] / [IocImportModule] ─ Transforms/TransformImportModule.cs + β”œβ”€β”€ [IocDiscover] / [IocDiscover] ─ Transforms/TransformDiscover.cs + β”œβ”€β”€ [IocContainer] ─ Transforms/TransformContainer.cs + └── IServiceProvider.GetService invocations ─ Transforms/IServiceProviderInvocations.cs + ↓ +Stage 2 Compilation / MSBuild / default-settings inputs + (BuildCompilationInfoProvider, BuildMsBuildPropertiesProvider, BuildDefaultLifetimeProvider) + ↓ +Stage 3 Per-registration processing (cacheable per registration) + ProcessSingleRegistration(RegistrationData, DefaultSettingsMap, ct) β†’ BasicRegistrationResult + ↓ +Stage 4 Closed-generic resolution + grouping + CombineAndResolveClosedGenerics β†’ ImmutableEquatableArray + β”œβ”€β†’ GroupRegistrationsForRegister β†’ RegisterOutputModel + └─→ FilterRegistrationsForContainer β†’ GroupRegistrationsForContainer β†’ ContainerWithGroups + ↓ +Stage 5 Emit + GenerateRegisterOutput β†’ {assemblyName}.ServiceRegistration.g.cs + GenerateContainerOutput β†’ {containerClassName}.Container.g.cs +``` + +The container branch splits on `ContainerModel.ExplicitOnly`: + +- **ExplicitOnly** uses `GroupExplicitOnlyRegistrations` and is independent of `serviceRegistrations` (its own caching branch). +- **Normal** combines with `serviceRegistrations`, then `Select(FilterRegistrationsForContainer)` acts as a caching barrier so unrelated `serviceRegistrations` changes do not invalidate downstream nodes. + +--- + +## Source layout + +```text +src/SourceGen.Ioc.SourceGenerator/ +β”œβ”€β”€ IocSourceGenerator.cs # Initialize() β€” pipeline wiring +β”œβ”€β”€ IocSourceGenerator.ConfigProviders.cs # MSBuild / Compilation / default-lifetime helpers +β”œβ”€β”€ Constants.cs (in Models/) +β”œβ”€β”€ RoslynExtensions.cs Β· TypeArgMap.cs Β· GlobalUsings.cs +β”œβ”€β”€ Transforms/ β€” Stage 1: attribute symbol β†’ data model (+ TransformExtensions.cs helpers) +β”œβ”€β”€ Processing/ β€” Stage 3 + part of Stage 4 (ProcessSingleRegistration, CombineAndResolveClosedGenerics) +β”œβ”€β”€ Grouping/ β€” Stage 4 grouping (GroupRegistrationsForRegister, GroupRegistrationsForContainer) +β”œβ”€β”€ Emit/ +β”‚ β”œβ”€β”€ Shared/ (CodeGenHelpers, SourceWriterExtensions, FeatureFilterHelper) +β”‚ β”œβ”€β”€ Register/ (GenerateRegisterOutput + RegisterOutputModel + RegisterEntry hierarchy + wrapper helpers) +β”‚ └── Container/ (GenerateContainerOutput + ContainerEntry hierarchy + ResolvedDependency + injection/interface/async helpers) +β”œβ”€β”€ Models/ β€” pipeline-shared immutable data models +β”œβ”€β”€ Analyzer/ β€” diagnostic analyzers (separate concern; see Analyzer/Spec/SPEC.spec.md) +└── Spec/ β€” this file + per-feature specs +``` + +> **Framework requirements** β€” The generator assembly targets `netstandard2.0`. Generated code and consumer projects require `.NET 10` (the `SourceGen.Ioc` runtime library targets `net10.0`). + +All files outside `Analyzer/` are `partial class IocSourceGenerator` in namespace `SourceGen.Ioc`. Models use namespace `SourceGen.Ioc.SourceGenerator.Models`. + +--- + +## Architecture pattern: discriminated unions for emission + +Both Register and Container pipelines pre-compute a discriminated union of "entry" instances during Stage 4. Stage 5 then walks them with polymorphic `Write*` methods β€” emission is pure string formatting, no shared mutable context. + +| Pipeline | Entry base | Subtypes | Write methods | +|---|---|---|---| +| Register | `RegisterEntry` | 7 (`Simple`, `Instance`, `Forwarding`, `Factory`, `Injection`, `AsyncInjection`, `Decorator`) + 3 wrapper structs (`Lazy`, `Func`, `Kvp`RegistrationEntry) | `WriteRegistration(SourceWriter, RegisterWriteContext)` | +| Container | `ContainerEntry` | 10 (6 service via `ServiceContainerEntry`, 3 wrapper, 1 collection) | `WriteField`, `WriteResolver`, `WriteEagerInit`, `WriteDisposal`, `WriteInit`, `WriteCollectionResolver`, `WriteLocalResolverEntries` | + +Container dependency lookups are pre-resolved into the **`ResolvedDependency`** hierarchy (18 subtypes such as `DirectServiceDependency`, `LazyInlineDependency`, `MultiParamFuncDependency`, `KvpInlineDependency`, `FallbackProviderDependency`, …) each implementing `FormatExpression(bool isOptional): string`. See `Emit/Container/ContainerEntry.cs` and `Emit/Container/ResolvedDependency.cs` for the full tables. + +--- + +## Key data models + +All under `Models/` unless noted; all immutable with value equality. + +| Model | Purpose | +|---|---| +| `RegistrationData` | Raw per-attribute payload (Stage 1 output) | +| `DefaultSettingsModel` / `DefaultSettingsMap` / `DefaultSettingsResult` | Defaults matching + multi-payload defaults result | +| `ImportModuleResult` | Per-module imported defaults + open generics | +| `BasicRegistrationResult` | Cacheable Stage 3 intermediate | +| `ServiceRegistrationModel` / `ServiceRegistrationWithTags` | Stage 4 registration record (+ tags) | +| `ContainerModel` | `[IocContainer]` input | +| `ContainerWithGroups` | Container + pre-grouped `ContainerRegistrationGroups` | +| `ContainerRegistrationGroups` | Grouped `ContainerEntry` arrays + `LastWinsByServiceType` lookup | +| `RegisterOutputModel` (`Emit/Register/`) | Top-level Register output model with per-tag `RegisterTagGroup` | +| `MsBuildProperties` | `RootNamespace`, `CustomIocName`, `Features` | +| `IocFeatures` | Feature-flag enum (`Register`, `Container`, `PropertyInject`, `FieldInject`, `MethodInject`, `AsyncMethodInject`) | +| `TypeData` and subtypes (`GenericTypeData`, `WrapperTypeData` family) | Type representation including wrapper-kind hierarchy (see [Wrapper kinds](#wrapper-kinds)) | + +--- + +## Inputs and configuration + +### Attributes + +| Attribute | Purpose | +|---|---| +| `IocRegisterAttribute()` | Mark a class for registration | +| `IocRegisterForAttribute()` | Register an external type | +| `IocRegisterDefaultsAttribute()` | Default settings + bulk `ImplementationTypes` | +| `IocImportModuleAttribute()` | Import another assembly's defaults / open generics | +| `IocDiscoverAttribute()` | Explicit closed-generic discovery | +| `IocContainerAttribute` | Mark a `partial class` as a generated container | +| `IocGenericFactoryAttribute` | Map generic factory type parameters | + +### MSBuild properties + +| Property | Default | Effect | +|---|---|---| +| `RootNamespace` | _assembly name_ | Namespace for generated output | +| `SourceGenIocName` | _assembly name_ | Override the generated method/class base name; if unset, falls back to assembly name (`"Generated"` when assembly name is absent) | +| `SourceGenIocDefaultLifetime` | `Transient` | Default lifetime when the attribute does not specify one | +| `SourceGenIocFeatures` | `Register,Container,PropertyInject,MethodInject` | Comma-separated feature flags (case-insensitive). `AsyncMethodInject` requires `MethodInject` (analyzer `SGIOC026`). | + +### IServiceProvider invocations collected + +`GetService`, `GetRequiredService`, `GetKeyedService`, `GetRequiredKeyedService`, `GetServices`, `GetKeyedServices` and their non-generic overloads. + +--- + +## Parse logic essentials + +| Topic | Rule | +|---|---| +| **Settings merge order** | explicit attribute β†’ matching defaults β†’ MSBuild `SourceGenIocDefaultLifetime` β†’ `Transient` | +| **Defaults match priority** | (1) implementation type itself, (2) closest base class, (3) first interface in `AllInterfaces` | +| **`ImplementationTypes` service derivation** | open generic β†’ `TargetServiceType` + configured `ServiceTypes`; closed/non-generic β†’ matched closed types from `AllInterfaces`/`AllBaseClasses`, falling back to `TargetServiceType` if no match (e.g. when framework metadata is invisible) | +| **Key interpretation** | `KeyType=Value` β†’ literal; `KeyType=Csharp` β†’ C# expression (`MyClass.Field`, `nameof(...)`) | +| **Inject attribute matching** | by name only β€” `IocInjectAttribute` or `InjectAttribute` (any namespace, e.g. `Microsoft.AspNetCore.Components.InjectAttribute`) | +| **Constructor selection** | (1) `[IocInject]`-marked, (2) primary, (3) most parameters | +| **Parameter resolution** | `[ServiceKey]` β†’ registration key; `[FromKeyedServices]`/`[IocInject(Key=…)]` β†’ keyed; `IServiceProvider` β†’ pass through; collection types β†’ element service type | + +### Wrapper kinds + +`WrapperKind` is a unified enum; each value has a dedicated `WrapperTypeData` subtype. + +```text +TypeData +└── GenericTypeData + └── WrapperTypeData + β”œβ”€β”€ CollectionWrapperTypeData β†’ Enumerable / ReadOnlyCollection / Collection / ReadOnlyList / List / Array + β”œβ”€β”€ LazyTypeData + β”œβ”€β”€ FuncTypeData + β”œβ”€β”€ TaskTypeData (resolves Task for async-init, Task.FromResult(...) otherwise) + β”œβ”€β”€ DictionaryTypeData + └── KeyValuePairTypeData +``` + +Wrappers nest: `IEnumerable>` parses to `EnumerableTypeData(LazyTypeData(IMyService))`. + +### Generic factory type mapping + +`IocGenericFactoryAttribute` maps service-type parameters to factory-method type parameters. Example for `IRequestHandler<,>`: + +```csharp +[IocGenericFactory(typeof(IRequestHandler, decimal>), typeof(decimal), typeof(int))] +public static IRequestHandler Create() => new Handler(); +// decimal β†’ T1, int β†’ T2 +``` + +--- + +## Spec index + +### Registration features + +| Feature | File | +|---|---| +| Basic registration | [Register.Basic.spec.md](Register.Basic.spec.md) | +| Decorators | [Register.Decorators.spec.md](Register.Decorators.spec.md) | +| Tags | [Register.Tags.spec.md](Register.Tags.spec.md) | +| Injection members | [Register.Injection.spec.md](Register.Injection.spec.md) | +| Imported modules | [Register.ImportModule.spec.md](Register.ImportModule.spec.md) | +| Open generics | [Register.Generics.spec.md](Register.Generics.spec.md) | +| `IServiceProvider` invocations | [Register.ServiceProviderInvocation.spec.md](Register.ServiceProviderInvocation.spec.md) | +| MSBuild configuration | [Register.MSBuild.spec.md](Register.MSBuild.spec.md) | +| Factory & instance | [Register.Factory.spec.md](Register.Factory.spec.md) | +| `KeyValuePair` | [Register.KeyValuePair.spec.md](Register.KeyValuePair.spec.md) | + +### Container features + +| Feature | File | +|---|---| +| Basic container | [Container.Basic.spec.md](Container.Basic.spec.md) | +| Service lifetime | [Container.Lifetime.spec.md](Container.Lifetime.spec.md) | +| Keyed services | [Container.KeyedServices.spec.md](Container.KeyedServices.spec.md) | +| Injection | [Container.Injection.spec.md](Container.Injection.spec.md) | +| Decorators | [Container.Decorators.spec.md](Container.Decorators.spec.md) | +| Imported modules | [Container.ImportModule.spec.md](Container.ImportModule.spec.md) | +| Factory & instance | [Container.Factory.spec.md](Container.Factory.spec.md) | +| Open generics | [Container.Generics.spec.md](Container.Generics.spec.md) | +| Collections & wrappers | [Container.Collections.spec.md](Container.Collections.spec.md) | +| Container options | [Container.Options.spec.md](Container.Options.spec.md) | +| Thread safety | [Container.ThreadSafety.spec.md](Container.ThreadSafety.spec.md) | +| Partial accessors | [Container.PartialAccessors.spec.md](Container.PartialAccessors.spec.md) | +| MVC & Blazor | [Container.AspNetCore.spec.md](Container.AspNetCore.spec.md) | +| Performance | [Container.Performance.spec.md](Container.Performance.spec.md) | diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/IServiceProviderInvocations.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/IServiceProviderInvocations.cs similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/IServiceProviderInvocations.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/IServiceProviderInvocations.cs diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformContainer.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformContainer.cs similarity index 96% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformContainer.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformContainer.cs index 9770341..d63e490 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformContainer.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformContainer.cs @@ -242,21 +242,15 @@ private static ImmutableEquatableArray ExtractImportedModules(INamedTy if(fullName == Constants.IocImportModuleAttributeFullName) { - // Non-generic: ModuleType is in constructor argument - if(attr.ConstructorArguments.Length > 0 && - attr.ConstructorArguments[0].Value is INamedTypeSymbol moduleType) - { + var moduleType = attr.GetImportedModuleType(); + if(moduleType is not null) modules.Add(moduleType.GetTypeData()); - } } else if(originalFullName == Constants.IocImportModuleAttributeFullName_T1) { - // Generic: ModuleType is the type argument - if(attrClass.IsGenericType && attrClass.TypeArguments.Length > 0 && - attrClass.TypeArguments[0] is INamedTypeSymbol genericModuleType) - { + var genericModuleType = attr.GetImportedModuleType(); + if(genericModuleType is not null) modules.Add(genericModuleType.GetTypeData()); - } } } diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformDefaultSettings.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformDefaultSettings.cs similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformDefaultSettings.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformDefaultSettings.cs diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformDiscover.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformDiscover.cs similarity index 100% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformDiscover.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformDiscover.cs diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.AttributeArguments.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.AttributeArguments.cs new file mode 100644 index 0000000..f0a4b4c --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.AttributeArguments.cs @@ -0,0 +1,400 @@ +namespace SourceGen.Ioc.SourceGenerator.Models; + +internal static partial class TransformExtensions +{ + extension(AttributeData attribute) + { + /// + /// Gets an array of type symbols from a named argument. + /// + /// The name of the named argument. + /// Whether to extract constructor parameters. + /// Whether to extract injection members (for decorators). + public ImmutableEquatableArray GetTypeArrayArgument( + string name, + bool extractConstructorParams = false, + bool extractInjectionMembers = false) + { + foreach(var namedArg in attribute.NamedArguments) + { + if(namedArg.Key.Equals(name, StringComparison.Ordinal) && !namedArg.Value.IsNull && namedArg.Value.Kind == TypedConstantKind.Array) + { + List result = []; + foreach(var value in namedArg.Value.Values) + { + if(value.Value is INamedTypeSymbol namedTypeSymbol) + { + result.Add(namedTypeSymbol.GetTypeData( + extractConstructorParams, + extractHierarchy: false, + visited: null, + semanticModel: null, + extractInjectionMembers)); + } + else if(value.Value is ITypeSymbol typeSymbol) + { + result.Add(typeSymbol.GetTypeData(extractConstructorParams)); + } + } + return result.ToImmutableEquatableArray(); + } + } + + return []; + } + + /// + /// Gets an array of type symbols from an attribute constructor argument of type params Type[]. + /// This is the constructor-argument counterpart to , and is used when + /// service types are supplied positionally to the attribute constructor instead of via a named Type[] argument. + /// + /// + /// This method scans the attribute's constructor arguments for an array (or params) argument that contains + /// type values (for example, params Type[] serviceTypes) and converts those instances + /// to . It skips non-type arguments, such as ServiceLifetime enum values. + /// + public ImmutableEquatableArray GetTypeArrayFromConstructorArgument(bool extractConstructorParams = false) + { + foreach(var ctorArg in attribute.ConstructorArguments) + { + // Look for an array argument containing type values + if(ctorArg.Kind == TypedConstantKind.Array && !ctorArg.IsNull) + { + List result = []; + foreach(var value in ctorArg.Values) + { + if(value.Value is ITypeSymbol typeSymbol) + { + result.Add(typeSymbol.GetTypeData(extractConstructorParams)); + } + } + + // Only return if we found type values + if(result.Count > 0) + return result.ToImmutableEquatableArray(); + } + } + + return []; + } + + public (bool HasArg, ServiceLifetime Lifetime) TryGetLifetime() + { + // First, check if lifetime is passed as a constructor argument (for generic attributes like IoCRegisterAttribute(ServiceLifetime.Scoped)) + foreach(var ctorArg in attribute.ConstructorArguments) + { + if(ctorArg.Type?.Name == nameof(ServiceLifetime) && ctorArg.Value is int lifetimeValue) + { + return (true, (ServiceLifetime)lifetimeValue); + } + } + + // Fall back to named argument + var (hasArg, val) = attribute.TryGetNamedArgument("Lifetime", 2); // Default is ServiceLifetime.Transient + return (hasArg, (ServiceLifetime)val); + } + + public (bool HasArg, bool Value) TryGetRegisterAllInterfaces() => + attribute.TryGetNamedArgument("RegisterAllInterfaces", false); + + public (bool HasArg, bool Value) TryGetRegisterAllBaseClasses() => + attribute.TryGetNamedArgument("RegisterAllBaseClasses", false); + + /// + /// Gets the service types from the attribute. + /// This method checks both named arguments and constructor arguments for service types. + /// + /// + /// The method first checks for a named argument "ServiceTypes" (e.g., ServiceTypes = [typeof(IService)]). + /// If not found, it checks constructor arguments for an array of types (e.g., params Type[] serviceTypes). + /// + public IEnumerable GetServiceTypeSymbols() + { + // First, try to get from named argument + var namedResult = attribute.GetTypeSymbolsFromNamedArgument("ServiceTypes"); + if(namedResult.Length > 0) + return namedResult; + + // Fall back to constructor argument (params Type[] serviceTypes) + foreach(var ctorArg in attribute.ConstructorArguments) + { + if(ctorArg.Kind == TypedConstantKind.Array && !ctorArg.IsNull) + { + List result = []; + foreach(var value in ctorArg.Values) + { + if(value.Value is INamedTypeSymbol namedTypeSymbol) + { + result.Add(namedTypeSymbol); + } + } + + if(result.Count > 0) + return result; + } + } + + return []; + } + + public ImmutableEquatableArray GetServiceTypes() + { + List result = []; + foreach(var serviceTypeSymbol in attribute.GetServiceTypeSymbols()) + { + result.Add(serviceTypeSymbol.GetTypeData()); + } + + return result.ToImmutableEquatableArray(); + } + + /// + /// Gets the service types from generic attribute type parameters (e.g., IoCRegisterAttribute<T1, T2>). + /// + public IEnumerable GetServiceTypeSymbolsFromGenericAttribute() + { + var attrClass = attribute.AttributeClass; + if(attrClass?.IsGenericType != true || attrClass.TypeArguments.Length == 0) + return []; + + List result = []; + foreach(var typeArg in attrClass.TypeArguments) + { + if(typeArg is INamedTypeSymbol namedType) + { + result.Add(namedType); + } + } + + return result; + } + + public ImmutableEquatableArray GetServiceTypesFromGenericAttribute() + { + List result = []; + foreach(var serviceTypeSymbol in attribute.GetServiceTypeSymbolsFromGenericAttribute()) + { + result.Add(serviceTypeSymbol.GetTypeData()); + } + + return result.ToImmutableEquatableArray(); + } + + public INamedTypeSymbol? GetImportedModuleType() + { + var attributeClass = attribute.AttributeClass; + if(attributeClass is null) + return null; + + if(attributeClass.IsGenericType) + { + if(attributeClass.TypeArguments.Length == 0) + return null; + + return attributeClass.TypeArguments[0] as INamedTypeSymbol; + } + + if(attribute.ConstructorArguments.Length == 0) + return null; + + return attribute.ConstructorArguments[0].Value as INamedTypeSymbol; + } + + public ImmutableEquatableArray GetDecorators() => + attribute.GetTypeArrayArgument("Decorators", extractConstructorParams: true, extractInjectionMembers: true); + + /// + /// Gets the ImplementationTypes array from the attribute. + /// Extracts implementation types with constructor parameters and hierarchy information, + /// using the same parsing logic as IocRegisterAttribute. + /// + public ImmutableEquatableArray GetImplementationTypes() => + attribute.GetTypeArrayArgumentWithHierarchy("ImplementationTypes"); + + /// + /// Gets the ImplementationTypes array as INamedTypeSymbol from the attribute. + /// Used when full symbol access is needed for injection member extraction. + /// + public ImmutableEquatableArray GetImplementationTypeSymbols() => + attribute.GetTypeSymbolsFromNamedArgument("ImplementationTypes"); + + /// + /// Gets an array of type symbols from a named argument. + /// Used when full symbol access is needed for further analysis. + /// + public ImmutableEquatableArray GetTypeSymbolsFromNamedArgument(string name) + { + foreach(var namedArg in attribute.NamedArguments) + { + if(namedArg.Key.Equals(name, StringComparison.Ordinal) && !namedArg.Value.IsNull && namedArg.Value.Kind == TypedConstantKind.Array) + { + List result = []; + foreach(var value in namedArg.Value.Values) + { + if(value.Value is INamedTypeSymbol namedTypeSymbol) + { + result.Add(namedTypeSymbol); + } + } + return result.ToImmutableEquatableArray(); + } + } + + return []; + } + + /// + /// Gets an array of type symbols from a named argument with full hierarchy extraction. + /// Used for ImplementationTypes where we need constructor params and all interfaces/base classes. + /// + public ImmutableEquatableArray GetTypeArrayArgumentWithHierarchy(string name) + { + foreach(var namedArg in attribute.NamedArguments) + { + if(namedArg.Key.Equals(name, StringComparison.Ordinal) && !namedArg.Value.IsNull && namedArg.Value.Kind == TypedConstantKind.Array) + { + List result = []; + foreach(var value in namedArg.Value.Values) + { + if(value.Value is INamedTypeSymbol namedTypeSymbol) + { + // Extract with constructor params and hierarchy, same as IocRegisterAttribute + result.Add(namedTypeSymbol.GetTypeData(extractConstructorParams: true, extractHierarchy: true)); + } + } + return result.ToImmutableEquatableArray(); + } + } + + return []; + } + + /// + /// Gets the Tags array from the attribute. + /// + public ImmutableEquatableArray GetTags() + { + foreach(var namedArg in attribute.NamedArguments) + { + if(namedArg.Key.Equals("Tags", StringComparison.Ordinal) && !namedArg.Value.IsNull && namedArg.Value.Kind == TypedConstantKind.Array) + { + List result = []; + foreach(var value in namedArg.Value.Values) + { + if(value.Value is string tag) + { + result.Add(tag); + } + } + return result.ToImmutableEquatableArray(); + } + } + + return []; + } + + /// + /// Checks if the attribute has Factory or Instance specified. + /// + /// A tuple indicating whether Factory and/or Instance are specified. + public (bool HasFactory, bool HasInstance) HasFactoryOrInstance() + { + bool hasFactory = false; + bool hasInstance = false; + + foreach(var namedArg in attribute.NamedArguments) + { + if(namedArg.Key == "Factory" && !namedArg.Value.IsNull) + { + hasFactory = true; + } + else if(namedArg.Key == "Instance" && !namedArg.Value.IsNull) + { + hasInstance = true; + } + + // Early exit if both found + if(hasFactory && hasInstance) + break; + } + + return (hasFactory, hasInstance); + } + + /// + /// Gets the target type from an IoCRegisterForAttribute. + /// For non-generic variant, extracts from constructor argument. + /// For generic variant (IoCRegisterForAttribute<T>), extracts from type parameter. + /// + /// The target type symbol, or null if not found. + public INamedTypeSymbol? GetTargetTypeFromRegisterForAttribute() + { + var attributeClass = attribute.AttributeClass; + if(attributeClass is null) + return null; + + // For generic IoCRegisterForAttribute, get T from type arguments + if(attributeClass.IsGenericType && attributeClass.TypeArguments.Length > 0) + { + return attributeClass.TypeArguments[0] as INamedTypeSymbol; + } + + // For non-generic IoCRegisterForAttribute, get from constructor argument + if(attribute.ConstructorArguments.Length > 0 && + attribute.ConstructorArguments[0].Value is INamedTypeSymbol targetType) + { + return targetType; + } + + return null; + } + + /// + /// Gets the Instance path from the attribute. + /// + /// Optional semantic model to resolve full access paths for nameof() expressions. + /// The static instance path (e.g., "MyService.Default"), or null if not specified. + public string? GetInstance(SemanticModel? semanticModel = null) + { + foreach(var namedArg in attribute.NamedArguments) + { + if(namedArg.Key == "Instance") + { + if(namedArg.Value.IsNull) + return null; + + // Try to get original syntax for nameof() expressions with full access path resolution + return attribute.TryGetNameof("Instance", semanticModel) + ?? namedArg.Value.Value?.ToString(); + } + } + + return null; + } + + /// + /// Determines if the attribute will cause registration of interfaces or base classes. + /// For open generic types, nested open generics are only a problem when registering interfaces/base classes. + /// + public bool WillRegisterInterfacesOrBaseClasses() + { + // Check if ServiceTypes is specified + var serviceTypes = attribute.GetServiceTypes(); + if(serviceTypes.Length > 0) + return true; + + // Check if RegisterAllInterfaces is true + var (hasRegisterAllInterfaces, registerAllInterfaces) = attribute.TryGetRegisterAllInterfaces(); + if(hasRegisterAllInterfaces && registerAllInterfaces) + return true; + + // Check if RegisterAllBaseClasses is true + var (hasRegisterAllBaseClasses, registerAllBaseClasses) = attribute.TryGetRegisterAllBaseClasses(); + if(hasRegisterAllBaseClasses && registerAllBaseClasses) + return true; + + // Only registering self, no interfaces/base classes + return false; + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.Constructors.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.Constructors.cs new file mode 100644 index 0000000..0b1d335 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.Constructors.cs @@ -0,0 +1,126 @@ +namespace SourceGen.Ioc.SourceGenerator.Models; + +internal static partial class TransformExtensions +{ + extension(INamedTypeSymbol typeSymbol) + { + public IMethodSymbol? SpecifiedOrPrimaryOrMostParametersConstructor + { + get + { + IMethodSymbol? injectCtor = null; + IMethodSymbol? primaryCtor = null; + IMethodSymbol? bestCtor = null; + int maxParameters = -1; + foreach(var ctor in typeSymbol.Constructors) + { + if(ctor.IsImplicitlyDeclared) + continue; + + if(ctor.IsStatic) + continue; + + if(ctor.DeclaredAccessibility is not (Accessibility.Public or Accessibility.Internal)) + continue; + + // IocInjectAttribute/InjectAttribute specified constructor - highest priority + if(ctor.GetAttributes().Any(attr => attr.AttributeClass?.IsInject == true)) + { + injectCtor = ctor; + continue; + } + + var syntaxRef = ctor.DeclaringSyntaxReferences.FirstOrDefault(); + // Primary constructor - second priority + if(syntaxRef?.GetSyntax() is TypeDeclarationSyntax) + { + primaryCtor = ctor; + continue; + } + + // Find constructor with most parameters - lowest priority + if(ctor.Parameters.Length > maxParameters) + { + maxParameters = ctor.Parameters.Length; + bestCtor = ctor; + } + } + + // Return by priority: [Inject] > primary > most parameters + return injectCtor ?? primaryCtor ?? bestCtor; + } + } + + /// + /// Extracts constructor parameters from a type and indicates whether the constructor was selected by [Inject] attribute. + /// + /// Set of visited types to prevent infinite recursion. + /// Optional semantic model for resolving nameof() expressions in service keys. Only used for top-level extraction, not passed to recursive calls. + /// A tuple containing the constructor parameters and whether the constructor has [Inject] attribute. + public (ImmutableEquatableArray Parameters, bool HasInjectConstructor) ExtractConstructorParametersWithInfo( + HashSet? visited = null, + SemanticModel? semanticModel = null) + { + // Check if we've already visited this type to prevent infinite recursion + if(visited is not null && !visited.Add(typeSymbol)) + { + return ([], false); + } + + // Get the original definition for open generic types to access constructors + var typeToInspect = typeSymbol.IsGenericType && typeSymbol.IsDefinition + ? typeSymbol + : typeSymbol.OriginalDefinition ?? typeSymbol; + + // Get the constructor: [Inject] marked > primary constructor > most parameters + var constructor = typeToInspect.SpecifiedOrPrimaryOrMostParametersConstructor; + if(constructor is null) + { + return ([], false); + } + + // Check if the selected constructor has [IocInject] or [Inject] attribute + bool hasInjectConstructor = constructor.GetAttributes() + .Any(static attr => attr.AttributeClass?.IsInject == true); + + visited ??= new HashSet(SymbolEqualityComparer.Default); + List parameters = []; + foreach(var param in constructor.Parameters) + { + var paramType = param.Type; + + // Get TypeData using the unified method with recursive constructor extraction + // Also extract hierarchy (interfaces) for generic types to enable IEnumerable detection + var paramTypeData = paramType is INamedTypeSymbol namedParamType + ? namedParamType.GetTypeData(extractConstructorParams: true, extractHierarchy: namedParamType.IsGenericType, visited: visited) + : paramType.GetTypeData(); + + // Check if parameter type is nullable (e.g., IDependency?) + var isNullable = param.NullableAnnotation == NullableAnnotation.Annotated; + + // Check if parameter has an explicit default value (for skipping unresolvable parameters) + var hasDefaultValue = param.HasExplicitDefaultValue; + + // Check for [FromKeyedServices], [Inject], or [ServiceKey] attribute + // SemanticModel is used to resolve nameof() expressions for top-level parameters + var (serviceKey, hasInjectAttribute, hasServiceKeyAttribute, hasFromKeyedServicesAttribute) = param.GetServiceKeyAndAttributeInfo(semanticModel); + + // Get the C# code representation of the default value + var defaultValue = hasDefaultValue ? ToDefaultValueCodeString(param.ExplicitDefaultValue) : null; + + parameters.Add(new ParameterData( + param.Name, + paramTypeData, + IsNullable: isNullable, + HasDefaultValue: hasDefaultValue, + DefaultValue: defaultValue, + ServiceKey: serviceKey, + HasInjectAttribute: hasInjectAttribute, + HasServiceKeyAttribute: hasServiceKeyAttribute, + HasFromKeyedServicesAttribute: hasFromKeyedServicesAttribute)); + } + + return (parameters.ToImmutableEquatableArray(), hasInjectConstructor); + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.DecoratorInjection.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.DecoratorInjection.cs new file mode 100644 index 0000000..016a844 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.DecoratorInjection.cs @@ -0,0 +1,181 @@ +namespace SourceGen.Ioc.SourceGenerator.Models; + +internal static partial class TransformExtensions +{ + extension(INamedTypeSymbol typeSymbol) + { + /// + /// Extracts injection members (properties, fields, methods with [IocInject]/[Inject] attributes) from the type. + /// This is used for both regular registrations and decorators. + /// + /// Optional semantic model to resolve full access paths for nameof() expressions. + /// An array of injection member data. + public ImmutableEquatableArray ExtractInjectionMembersForDecorator(SemanticModel? semanticModel = null) + { + List? injectionMembers = null; + + foreach(var (member, injectAttribute) in typeSymbol.GetInjectedMembers()) + { + // Extract key information from IocInjectAttribute/InjectAttribute + var (key, _, _) = injectAttribute.GetKeyInfo(semanticModel); + + InjectionMemberData? memberData = member switch + { + IPropertySymbol property => CreateDecoratorPropertyInjection(property, key), + IFieldSymbol field => CreateDecoratorFieldInjection(field, key), + IMethodSymbol method => CreateDecoratorMethodInjection(method, key, semanticModel), + _ => null + }; + + if(memberData is not null) + { + injectionMembers ??= []; + injectionMembers.Add(memberData); + } + } + + return injectionMembers?.ToImmutableEquatableArray() ?? []; + } + + private static InjectionMemberData CreateDecoratorPropertyInjection(IPropertySymbol property, string? key) + { + var propertyType = property.Type.GetTypeData(); + var isNullable = property.NullableAnnotation == NullableAnnotation.Annotated; + + // Try to get the default value from property initializer + var (hasDefaultValue, defaultValue) = GetDecoratorPropertyDefaultValue(property); + + return new InjectionMemberData( + InjectionMemberType.Property, + property.Name, + propertyType, + null, + key, + isNullable, + hasDefaultValue, + defaultValue); + } + + private static InjectionMemberData CreateDecoratorFieldInjection(IFieldSymbol field, string? key) + { + var fieldType = field.Type.GetTypeData(); + var isNullable = field.NullableAnnotation == NullableAnnotation.Annotated; + + // Try to get the default value from field initializer + var (hasDefaultValue, defaultValue) = GetDecoratorFieldDefaultValue(field); + + return new InjectionMemberData( + InjectionMemberType.Field, + field.Name, + fieldType, + null, + key, + isNullable, + hasDefaultValue, + defaultValue); + } + + private static (bool HasDefaultValue, string? DefaultValue) GetDecoratorPropertyDefaultValue(IPropertySymbol property) + { + var syntaxRef = property.DeclaringSyntaxReferences.FirstOrDefault(); + if(syntaxRef?.GetSyntax() is not PropertyDeclarationSyntax propertySyntax) + return (false, null); + + var initializer = propertySyntax.Initializer; + if(initializer is null) + return (false, null); + + // Check if it's a null literal or null-forgiving expression (null!) + if(IsDecoratorNullExpression(initializer.Value)) + { + return (true, null); + } + + return (true, initializer.Value.ToString()); + } + + private static (bool HasDefaultValue, string? DefaultValue) GetDecoratorFieldDefaultValue(IFieldSymbol field) + { + var syntaxRef = field.DeclaringSyntaxReferences.FirstOrDefault(); + var syntax = syntaxRef?.GetSyntax(); + + // Field can be declared in VariableDeclaratorSyntax + EqualsValueClauseSyntax? initializer = syntax switch + { + VariableDeclaratorSyntax variableDeclarator => variableDeclarator.Initializer, + _ => null + }; + + if(initializer is null) + return (false, null); + + // Check if it's a null literal or null-forgiving expression (null!) + if(IsDecoratorNullExpression(initializer.Value)) + { + return (true, null); + } + + return (true, initializer.Value.ToString()); + } + + private static bool IsDecoratorNullExpression(ExpressionSyntax expression) + { + // Direct null literal + if(expression is LiteralExpressionSyntax literal && + literal.Kind() == SyntaxKind.NullLiteralExpression) + { + return true; + } + + // Null-forgiving expression: null! + if(expression is PostfixUnaryExpressionSyntax postfix && + postfix.Kind() == SyntaxKind.SuppressNullableWarningExpression && + postfix.Operand is LiteralExpressionSyntax innerLiteral && + innerLiteral.Kind() == SyntaxKind.NullLiteralExpression) + { + return true; + } + + return false; + } + + private static InjectionMemberData CreateDecoratorMethodInjection(IMethodSymbol method, string? key, SemanticModel? semanticModel) + { + var parameters = method.Parameters + .Select(p => + { + var (serviceKey, hasInjectAttribute, hasServiceKeyAttribute, hasFromKeyedServicesAttribute) = p.GetServiceKeyAndAttributeInfo(semanticModel); + return new ParameterData( + p.Name, + p.Type.GetTypeData(), + IsNullable: p.NullableAnnotation == NullableAnnotation.Annotated, + HasDefaultValue: p.HasExplicitDefaultValue, + DefaultValue: p.HasExplicitDefaultValue ? DecoratorToDefaultValueCodeString(p.ExplicitDefaultValue) : null, + ServiceKey: serviceKey, + HasInjectAttribute: hasInjectAttribute, + HasServiceKeyAttribute: hasServiceKeyAttribute, + HasFromKeyedServicesAttribute: hasFromKeyedServicesAttribute); + }) + .ToImmutableEquatableArray(); + + return new InjectionMemberData( + InjectionMemberType.Method, + method.Name, + null, + parameters, + key); + } + + private static string? DecoratorToDefaultValueCodeString(object? value) + { + return value switch + { + null => null, + string s => $"\"{s}\"", + char c => $"'{c}'", + bool b => b ? "true" : "false", + _ => value.ToString() + }; + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.DefaultSettings.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.DefaultSettings.cs new file mode 100644 index 0000000..9d7bb1d --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.DefaultSettings.cs @@ -0,0 +1,100 @@ +namespace SourceGen.Ioc.SourceGenerator.Models; + +internal static partial class TransformExtensions +{ + extension(AttributeData attribute) + { + /// + /// Extracts default settings from an IoCRegisterDefaultSettingsAttribute. + /// + /// Optional semantic model for resolving Factory method data. + /// The default settings model, or null if the attribute data is invalid. + public DefaultSettingsModel? ExtractDefaultSettings(SemanticModel? semanticModel = null) + { + if(attribute.ConstructorArguments.Length < 2) + return null; + if(attribute.ConstructorArguments[0].Value is not INamedTypeSymbol targetServiceType) + return null; + if(attribute.ConstructorArguments[1].Value is not int lifetime) + return null; + + var (_, registerAllInterfaces) = attribute.TryGetRegisterAllInterfaces(); + var (_, registerAllBaseClasses) = attribute.TryGetRegisterAllBaseClasses(); + var serviceTypes = attribute.GetServiceTypes(); + var typeData = targetServiceType.GetTypeData(); + var decorators = attribute.GetDecorators(); + var tags = attribute.GetTags(); + + // Get factory method data if semantic model is provided + FactoryMethodData? factory = null; + if(semanticModel is not null) + { + factory = attribute.GetFactoryMethodData(semanticModel); + } + + // Get implementation types with constructor params and hierarchy (same as IocRegisterAttribute) + var implementationTypes = attribute.GetImplementationTypes(); + + return new DefaultSettingsModel( + typeData, + (ServiceLifetime)lifetime, + registerAllInterfaces, + registerAllBaseClasses, + serviceTypes, + decorators, + tags, + factory, + implementationTypes); + } + + /// + /// Extracts default settings from a generic IoCRegisterDefaultsAttribute (e.g., IoCRegisterDefaultsAttribute<T>). + /// The target service type is specified via type parameter instead of constructor argument. + /// + /// Optional semantic model for resolving Factory method data. + /// The default settings model, or null if the attribute data is invalid. + public DefaultSettingsModel? ExtractDefaultSettingsFromGenericAttribute(SemanticModel? semanticModel = null) + { + var attrClass = attribute.AttributeClass; + if(attrClass?.IsGenericType != true || attrClass.TypeArguments.Length == 0) + return null; + + if(attrClass.TypeArguments[0] is not INamedTypeSymbol targetServiceType) + return null; + + // Lifetime is the first constructor argument for the generic version + if(attribute.ConstructorArguments.Length < 1) + return null; + if(attribute.ConstructorArguments[0].Value is not int lifetime) + return null; + + var (_, registerAllInterfaces) = attribute.TryGetRegisterAllInterfaces(); + var (_, registerAllBaseClasses) = attribute.TryGetRegisterAllBaseClasses(); + var serviceTypes = attribute.GetServiceTypes(); + var typeData = targetServiceType.GetTypeData(); + var decorators = attribute.GetDecorators(); + var tags = attribute.GetTags(); + + // Get factory method data if semantic model is provided + FactoryMethodData? factory = null; + if(semanticModel is not null) + { + factory = attribute.GetFactoryMethodData(semanticModel); + } + + // Get implementation types with constructor params and hierarchy (same as IocRegisterAttribute) + var implementationTypes = attribute.GetImplementationTypes(); + + return new DefaultSettingsModel( + typeData, + (ServiceLifetime)lifetime, + registerAllInterfaces, + registerAllBaseClasses, + serviceTypes, + decorators, + tags, + factory, + implementationTypes); + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.FactoryMethod.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.FactoryMethod.cs new file mode 100644 index 0000000..47b2457 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.FactoryMethod.cs @@ -0,0 +1,296 @@ +namespace SourceGen.Ioc.SourceGenerator.Models; + +internal static partial class TransformExtensions +{ + extension(AttributeData attribute) + { + /// + /// Extracts from the registration attribute's + /// GenericFactoryTypeMapping named property. + /// Used as a fallback when [IocGenericFactory] is not present on the factory method. + /// + /// The generic factory type mapping, or null if not specified or invalid. + public GenericFactoryTypeMapping? ExtractGenericFactoryMappingFromAttributeProperty() + { + foreach(var namedArg in attribute.NamedArguments) + { + if(namedArg.Key != "GenericFactoryTypeMapping") + continue; + + if(namedArg.Value.Kind != TypedConstantKind.Array || namedArg.Value.IsNull) + return null; + + var typeArray = namedArg.Value.Values; + if(typeArray.Length < 2) + return null; + + if(typeArray[0].Value is not INamedTypeSymbol serviceTypeTemplate) + return null; + + var serviceTypeTemplateData = serviceTypeTemplate.GetTypeData(); + + var placeholderMap = new Dictionary(StringComparer.Ordinal); + for(int i = 1; i < typeArray.Length; i++) + { + if(typeArray[i].Value is ITypeSymbol placeholderType) + { + var placeholderTypeName = placeholderType.FullyQualifiedName; + if(placeholderMap.ContainsKey(placeholderTypeName)) + return null; // Duplicate placeholder + placeholderMap[placeholderTypeName] = i - 1; + } + } + + if(placeholderMap.Count != typeArray.Length - 1) + return null; + + return new GenericFactoryTypeMapping( + serviceTypeTemplateData, + placeholderMap.ToImmutableEquatableDictionary()); + } + + return null; + } + + /// + /// Gets the Factory method data from the attribute, including parameter and return type information. + /// When the resolved factory method is generic but has no [IocGenericFactory] attribute, + /// falls back to the GenericFactoryTypeMapping property on the registration attribute. + /// + /// Semantic model to resolve method symbols. + /// The factory method data, or null if not specified. + public FactoryMethodData? GetFactoryMethodData(SemanticModel semanticModel) + { + var syntaxReference = attribute.ApplicationSyntaxReference; + if(syntaxReference?.GetSyntax() is not AttributeSyntax attributeSyntax) + return null; + + var argumentList = attributeSyntax.ArgumentList; + if(argumentList is null) + return null; + + foreach(var argument in argumentList.Arguments) + { + if(argument.NameEquals?.Name.Identifier.Text != "Factory") + continue; + + // Check if the expression is a nameof() invocation + if(argument.Expression is InvocationExpressionSyntax invocation && + invocation.Expression is IdentifierNameSyntax identifierName && + identifierName.Identifier.Text == "nameof" && + invocation.ArgumentList.Arguments.Count == 1) + { + var nameofArgument = invocation.ArgumentList.Arguments[0].Expression; + var methodSymbol = ResolveMethodSymbol(nameofArgument, semanticModel); + + if(methodSymbol is not null) + { + var factoryData = CreateFactoryMethodData(methodSymbol, semanticModel); + + // Fallback: if method is generic but has no [IocGenericFactory], check attribute's GenericFactoryTypeMapping + if(factoryData.GenericTypeMapping is null && methodSymbol.TypeParameters.Length > 0) + { + var mappingFromAttr = attribute.ExtractGenericFactoryMappingFromAttributeProperty(); + if(mappingFromAttr is not null) + factoryData = factoryData with { GenericTypeMapping = mappingFromAttr }; + } + + return factoryData; + } + + // Fallback: get path from nameof expression + var nameofPath = ResolveNameofExpression(nameofArgument, semanticModel) + ?? nameofArgument.ToFullString().Trim(); + return new FactoryMethodData(nameofPath, HasServiceProvider: true, HasKey: false, ReturnTypeName: null, AdditionalParameters: []); + } + + // String literal - cannot determine parameters, assume full signature + if(argument.Expression is LiteralExpressionSyntax literal && + literal.Token.Value is string literalPath) + { + return new FactoryMethodData(literalPath, HasServiceProvider: true, HasKey: false, ReturnTypeName: null, AdditionalParameters: []); + } + } + + return null; + } + } + + extension(IMethodSymbol methodSymbol) + { + /// + /// Creates FactoryMethodData from a method symbol. + /// Analyzes factory method parameters: + /// - IServiceProvider: Will be passed the service provider directly + /// - [ServiceKey] attribute: Will be passed the registration key value + /// - Other parameters: Will be resolved from the service provider using the same logic as [IocInject] methods + /// Also extracts [IocGenericFactory] attribute if present for generic factory method support. + /// + public FactoryMethodData CreateFactoryMethodData(SemanticModel? semanticModel = null) + { + var path = methodSymbol.FullAccessPath; + bool hasServiceProvider = false; + bool hasKey = false; + List? additionalParameters = null; + + foreach(var param in methodSymbol.Parameters) + { + var paramTypeName = param.Type.FullyQualifiedName; + + // Check for IServiceProvider + if(paramTypeName is "global::System.IServiceProvider" or "System.IServiceProvider") + { + hasServiceProvider = true; + continue; + } + + // Check for [ServiceKey] attribute + bool hasServiceKeyAttribute = false; + foreach(var attribute in param.GetAttributes()) + { + var attrClass = attribute.AttributeClass; + if(attrClass is null) + continue; + + if(attrClass.Name == "ServiceKeyAttribute" + && attrClass.ContainingNamespace?.ToDisplayString() == "Microsoft.Extensions.DependencyInjection") + { + hasServiceKeyAttribute = true; + hasKey = true; + break; + } + } + + // Skip [ServiceKey] parameters from additional parameters + if(hasServiceKeyAttribute) + continue; + + // Collect additional parameter info using the same logic as [IocInject] methods + var (serviceKey, hasInjectAttribute, _, hasFromKeyedServicesAttribute) = param.GetServiceKeyAndAttributeInfo(semanticModel); + var parameterData = new ParameterData( + param.Name, + param.Type.GetTypeData(), + IsNullable: param.NullableAnnotation == NullableAnnotation.Annotated, + HasDefaultValue: param.HasExplicitDefaultValue, + DefaultValue: param.HasExplicitDefaultValue ? ToDefaultValueCodeString(param.ExplicitDefaultValue) : null, + ServiceKey: serviceKey, + HasInjectAttribute: hasInjectAttribute, + HasServiceKeyAttribute: false, // Already handled above + HasFromKeyedServicesAttribute: hasFromKeyedServicesAttribute); + + additionalParameters ??= []; + additionalParameters.Add(parameterData); + } + + // Always store the return type for runtime comparison + var returnTypeName = methodSymbol.ReturnType.FullyQualifiedName; + + // Extract [IocGenericFactory] attribute if present + var genericTypeMapping = methodSymbol.ExtractGenericFactoryMapping(); + var typeParameterCount = methodSymbol.TypeParameters.Length; + + return new FactoryMethodData( + path, + hasServiceProvider, + hasKey, + returnTypeName, + additionalParameters?.ToImmutableEquatableArray() ?? [], + genericTypeMapping, + typeParameterCount); + } + + /// + /// Extracts [IocGenericFactory] attribute from the method symbol and builds the type mapping. + /// + public GenericFactoryTypeMapping? ExtractGenericFactoryMapping() + { + // Only applicable to generic methods + if(methodSymbol.TypeParameters.Length == 0) + { + return null; + } + + // Find [IocGenericFactory] attribute + AttributeData? genericFactoryAttr = null; + foreach(var attr in methodSymbol.GetAttributes()) + { + var attrClass = attr.AttributeClass; + if(attrClass is null) + continue; + + var fullName = attrClass.ToDisplayString(); + if(fullName == Constants.IocGenericFactoryAttributeFullName) + { + genericFactoryAttr = attr; + break; + } + } + + if(genericFactoryAttr is null) + { + return null; + } + + // Extract GenericTypeMap array from constructor argument + // [IocGenericFactory(typeof(IRequestHandler>), typeof(int))] + // - First type: service type template with placeholders + // - Following types: map to factory method type parameters in order + if(genericFactoryAttr.ConstructorArguments.Length == 0) + { + return null; + } + + var firstArg = genericFactoryAttr.ConstructorArguments[0]; + if(firstArg.Kind != TypedConstantKind.Array || firstArg.Values.IsDefaultOrEmpty) + { + return null; + } + + var typeArray = firstArg.Values; + if(typeArray.Length < 2) + { + return null; // Need at least service type template and one placeholder mapping + } + + // First type is the service type template + if(typeArray[0].Value is not INamedTypeSymbol serviceTypeTemplate) + { + return null; + } + + var serviceTypeTemplateData = serviceTypeTemplate.GetTypeData(); + + // Build placeholder to type parameter index map + // Following types (index 1, 2, ...) map to factory method's type parameters (index 0, 1, ...) + var placeholderMap = new Dictionary(StringComparer.Ordinal); + var expectedPlaceholderCount = typeArray.Length - 1; + for(int i = 1; i < typeArray.Length; i++) + { + if(typeArray[i].Value is ITypeSymbol placeholderType) + { + var placeholderTypeName = placeholderType.FullyQualifiedName; + + // If the same placeholder type is used multiple times, the mapping is invalid + // because we cannot distinguish which type argument maps to which type parameter + if(placeholderMap.ContainsKey(placeholderTypeName)) + { + return null; + } + + // Map placeholder type to factory method's type parameter index (0-based) + placeholderMap[placeholderTypeName] = i - 1; + } + } + + // All placeholder types must be unique and present + if(placeholderMap.Count != expectedPlaceholderCount) + { + return null; + } + + return new GenericFactoryTypeMapping( + serviceTypeTemplateData, + placeholderMap.ToImmutableEquatableDictionary()); + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.InjectionMembers.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.InjectionMembers.cs new file mode 100644 index 0000000..a120b3c --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.InjectionMembers.cs @@ -0,0 +1,64 @@ +namespace SourceGen.Ioc.SourceGenerator.Models; + +internal static partial class TransformExtensions +{ + extension(INamedTypeSymbol typeSymbol) + { + public bool IsInject => + typeSymbol.Name is "IocInjectAttribute" or "InjectAttribute"; + + /// + /// Enumerates members (properties, fields, methods) marked with IocInjectAttribute/InjectAttribute. + /// This is a shared method used by both Analyzer (ServiceInfo) and Generator (RegistrationData). + /// + /// + /// The method filters members based on: + /// - Non-static members only + /// - Properties with a setter + /// - Non-readonly fields + /// - Ordinary methods that return (sync) or non-generic + /// (async, when AsyncMethodInject + /// feature is enabled), and are not generic + /// + /// + /// An enumerable of tuples containing the member symbol and its inject attribute. + /// Analyzer can use ISymbol directly; Generator can convert to InjectionMemberData. + /// + public IEnumerable<(ISymbol Member, AttributeData InjectAttribute)> GetInjectedMembers() + { + // For unbound generic types (e.g., LoggingDecorator<,>), we need to use OriginalDefinition + // to get the actual member declarations with their attributes + var typeToInspect = typeSymbol.IsUnboundGenericType ? typeSymbol.OriginalDefinition : typeSymbol; + + foreach(var member in typeToInspect.GetMembers()) + { + // Skip static members + if(member.IsStatic) + continue; + + // Check if the member has IocInjectAttribute/InjectAttribute (by name only) + var injectAttribute = member.GetAttributes() + .FirstOrDefault(static attr => attr.AttributeClass?.IsInject == true); + + if(injectAttribute is null) + continue; + + // Validate member is injectable based on type + var isInjectable = member switch + { + IPropertySymbol property => property.SetMethod is not null, + IFieldSymbol field => !field.IsReadOnly, + IMethodSymbol method => method.MethodKind == MethodKind.Ordinary + && (method.ReturnsVoid || RoslynExtensions.IsNonGenericTaskReturnType(method)) + && !method.IsGenericMethod, + _ => false + }; + + if(isInjectable) + { + yield return (member, injectAttribute); + } + } + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.KeyInfo.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.KeyInfo.cs new file mode 100644 index 0000000..07a46c9 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.KeyInfo.cs @@ -0,0 +1,153 @@ +namespace SourceGen.Ioc.SourceGenerator.Models; + +internal static partial class TransformExtensions +{ + extension(AttributeData attribute) + { + /// + /// Gets all key-related information from the attribute in a single pass: + /// key string, key type, and key value type symbol (with optional nameof() resolution). + /// + /// Optional semantic model to resolve nameof() expression types and full access paths. + /// + /// A tuple containing: + /// - Key: The key string, or null if no key is specified. + /// - KeyType: The key type (0 = Value, 1 = Csharp). + /// - KeyValueTypeSymbol: The type symbol of the key value, or null when the type cannot be determined. + /// + public (string? Key, int KeyType, ITypeSymbol? KeyValueTypeSymbol) GetKeyInfo(SemanticModel? semanticModel = null) + { + var keyType = attribute.GetNamedArgument("KeyType", 0); + var isCsharpKeyType = keyType == 1; + + // First, check if key is passed as a constructor argument (e.g., InjectAttribute(object key)) + if(attribute.ConstructorArguments.Length > 0) + { + var ctorArg = attribute.ConstructorArguments[0]; + // Skip if the first argument is a type, lifetime enum, or array (e.g., IoCRegisterDefaultsAttribute) + if(ctorArg.Type?.Name != nameof(ServiceLifetime) + && ctorArg.Kind != TypedConstantKind.Type + && ctorArg.Kind != TypedConstantKind.Array + && !ctorArg.IsNull) + { + if(isCsharpKeyType) + { + // Try to get original syntax for nameof() expressions with full access path resolution + var key = attribute.TryGetNameofFromConstructorArg(0, semanticModel) + ?? ctorArg.Value?.ToString(); + var keyValueType = TryResolveNameofTypeFromConstructorArg(attribute, 0, semanticModel); + return (key, keyType, keyValueType); + } + + // Value key: treat the primitive constant as CSharp code + return (ctorArg.GetPrimitiveConstantString(), 1, ctorArg.Type); + } + } + + // Fall back to named argument + foreach(var namedArg in attribute.NamedArguments) + { + if(namedArg.Key != "Key") + continue; + + if(namedArg.Value.IsNull) + return (null, keyType, null); + + if(isCsharpKeyType) + { + // Try to get original syntax for nameof() expressions with full access path resolution + var key = attribute.TryGetNameof("Key", semanticModel) + ?? namedArg.Value.Value?.ToString(); + var keyValueType = TryResolveNameofTypeFromNamedArg(attribute, "Key", semanticModel); + return (key, keyType, keyValueType); + } + + // Value key: treat the primitive constant as CSharp code + return (namedArg.Value.GetPrimitiveConstantString(), 1, namedArg.Value.Type); + } + + return (null, keyType, null); + } + + /// + /// Tries to resolve the type of a nameof() expression in a constructor argument. + /// Returns null if the argument is not a nameof() expression or cannot be resolved. + /// + private static ITypeSymbol? TryResolveNameofTypeFromConstructorArg(AttributeData attr, int argumentIndex, SemanticModel? semanticModel) + { + if(semanticModel is null) + return null; + + var syntaxReference = attr.ApplicationSyntaxReference; + if(syntaxReference?.GetSyntax() is not AttributeSyntax attributeSyntax) + return null; + + var argumentList = attributeSyntax.ArgumentList; + if(argumentList is null || argumentList.Arguments.Count <= argumentIndex) + return null; + + var argument = argumentList.Arguments[argumentIndex]; + if(argument.NameEquals is not null) + return null; + + return ResolveNameofExpressionType(argument.Expression, semanticModel); + } + + /// + /// Tries to resolve the type of a nameof() expression in a named argument. + /// Returns null if the argument is not a nameof() expression or cannot be resolved. + /// + private static ITypeSymbol? TryResolveNameofTypeFromNamedArg(AttributeData attr, string argumentName, SemanticModel? semanticModel) + { + if(semanticModel is null) + return null; + + var syntaxReference = attr.ApplicationSyntaxReference; + if(syntaxReference?.GetSyntax() is not AttributeSyntax attributeSyntax) + return null; + + var argumentList = attributeSyntax.ArgumentList; + if(argumentList is null) + return null; + + foreach(var argument in argumentList.Arguments) + { + if(argument.NameEquals?.Name.Identifier.Text == argumentName) + { + return ResolveNameofExpressionType(argument.Expression, semanticModel); + } + } + + return null; + } + + /// + /// If the expression is a nameof() invocation, resolves the referenced symbol's type. + /// Returns null for non-nameof expressions. + /// + private static ITypeSymbol? ResolveNameofExpressionType(ExpressionSyntax expression, SemanticModel semanticModel) + { + if(expression is not InvocationExpressionSyntax invocation + || invocation.Expression is not IdentifierNameSyntax identifierName + || identifierName.Identifier.Text != "nameof" + || invocation.ArgumentList.Arguments.Count != 1) + { + return null; + } + + var nameofArgument = invocation.ArgumentList.Arguments[0].Expression; + var symbolInfo = semanticModel.GetSymbolInfo(nameofArgument); + var symbol = symbolInfo.Symbol ?? symbolInfo.CandidateSymbols.FirstOrDefault(); + + return symbol switch + { + IFieldSymbol field => field.Type, + IPropertySymbol property => property.Type, + IMethodSymbol method => method.ReturnType, + ILocalSymbol local => local.Type, + IParameterSymbol param => param.Type, + _ => null, + }; + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.Parameters.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.Parameters.cs new file mode 100644 index 0000000..a622d5a --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.Parameters.cs @@ -0,0 +1,72 @@ +namespace SourceGen.Ioc.SourceGenerator.Models; + +internal static partial class TransformExtensions +{ + extension(IParameterSymbol param) + { + /// + /// Gets the service key, injection attribute info, and [ServiceKey]/[FromKeyedServices] attribute from a parameter. + /// [FromKeyedServices] takes precedence over [Inject] for service key resolution. + /// HasInjectAttribute is only true for [Inject] attribute (not [FromKeyedServices], which MS.DI handles automatically). + /// HasServiceKeyAttribute indicates the parameter is marked with [ServiceKey] from Microsoft.Extensions.DependencyInjection. + /// HasFromKeyedServicesAttribute indicates the parameter is marked with [FromKeyedServices] from Microsoft.Extensions.DependencyInjection. + /// + /// A tuple containing the service key (if any), whether the parameter has [Inject] attribute, [ServiceKey] attribute, and [FromKeyedServices] attribute. + public (string? ServiceKey, bool HasInjectAttribute, bool HasServiceKeyAttribute, bool HasFromKeyedServicesAttribute) GetServiceKeyAndAttributeInfo(SemanticModel? semanticModel = null) + { + string? serviceKey = null; + bool hasInjectAttribute = false; + bool hasServiceKeyAttribute = false; + bool hasFromKeyedServicesAttribute = false; + + foreach(var attribute in param.GetAttributes()) + { + var attrClass = attribute.AttributeClass; + if(attrClass is null) + continue; + + // Check for Microsoft.Extensions.DependencyInjection.ServiceKeyAttribute + if(attrClass.Name == "ServiceKeyAttribute" + && attrClass.ContainingNamespace?.ToDisplayString() == "Microsoft.Extensions.DependencyInjection") + { + hasServiceKeyAttribute = true; + continue; + } + + // Check for Microsoft.Extensions.DependencyInjection.FromKeyedServicesAttribute (higher priority for key) + // Note: [FromKeyedServices] is handled by MS.DI automatically, so we don't set hasInjectAttribute + if(attrClass.Name == "FromKeyedServicesAttribute" + && attrClass.ContainingNamespace?.ToDisplayString() == "Microsoft.Extensions.DependencyInjection") + { + hasFromKeyedServicesAttribute = true; + // The key is the first constructor argument + if(attribute.ConstructorArguments.Length > 0) + { + var keyArg = attribute.ConstructorArguments[0]; + if(!keyArg.IsNull && keyArg.Value is not null) + { + serviceKey = keyArg.GetPrimitiveConstantString(); + } + } + + // [FromKeyedServices] found, but continue to check for [Inject] as well + continue; + } + + // Check for IocInjectAttribute/InjectAttribute (by name only, to support third-party attributes) + if(attrClass.IsInject) + { + hasInjectAttribute = true; + // Only use [Inject] key if no [FromKeyedServices] key was found + if(serviceKey is null) + { + var (key, _, _) = attribute.GetKeyInfo(semanticModel); + serviceKey = key; + } + } + } + + return (serviceKey, hasInjectAttribute, hasServiceKeyAttribute, hasFromKeyedServicesAttribute); + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.TypeData.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.TypeData.cs new file mode 100644 index 0000000..337754f --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformExtensions.TypeData.cs @@ -0,0 +1,426 @@ +namespace SourceGen.Ioc.SourceGenerator.Models; + +internal static partial class TransformExtensions +{ + private static string GetNameWithoutGeneric(string typeName) + { + int angleIndex = typeName.IndexOf('<'); + return angleIndex > 0 ? typeName[..angleIndex] : typeName; + } + + private static WrapperKind NormalizeGeneratorWrapperKind(WrapperKind kind) + => kind is WrapperKind.ValueTask ? WrapperKind.None : kind; + + extension(ITypeSymbol typeSymbol) + { + public TypeData GetTypeData( + bool extractConstructorParams = false, + bool extractHierarchy = false, + HashSet? visited = null) + { + if(typeSymbol is INamedTypeSymbol namedTypeSymbol) + return namedTypeSymbol.GetTypeData(extractConstructorParams, extractHierarchy, visited); + + // Handle array types specially - extract element type information + if(typeSymbol is IArrayTypeSymbol arrayTypeSymbol) + return arrayTypeSymbol.GetTypeData(extractConstructorParams, extractHierarchy, visited); + + var name = typeSymbol.FullyQualifiedName; + return TypeData.CreateSimple(name); + } + } + + extension(INamedTypeSymbol typeSymbol) + { + /// + /// Gets the type data for this type symbol. + /// + /// Whether to extract constructor parameters recursively. + /// Whether to extract all interfaces and base classes. + /// Set of visited types to prevent infinite recursion during constructor parameter extraction. + /// Optional semantic model for resolving nameof() expressions. Only used for top-level extraction, not passed to recursive calls. + /// Whether to extract injection members (properties, fields, methods with [IocInject] attributes). Used for decorators. + public TypeData GetTypeData( + bool extractConstructorParams = false, + bool extractHierarchy = false, + HashSet? visited = null, + SemanticModel? semanticModel = null, + bool extractInjectionMembers = false) + { + visited = extractConstructorParams + ? (visited ?? new(SymbolEqualityComparer.Default)) + : null; + + // Build type name - for unbound generics, use actual type parameter names + var typeName = typeSymbol.BuildTypeName(); + + // Extract type parameters with full constraints + ImmutableEquatableArray? typeParameters = null; + if(typeSymbol.IsGenericType && typeSymbol.TypeArguments.Length > 0) + { + typeParameters = typeSymbol.ExtractTypeParameters(extractConstraints: true, depth: 0); + } + + ImmutableEquatableArray? constructorParams = null; + bool hasInjectConstructor = false; + if(extractConstructorParams && visited is not null) + { + // Pass semanticModel only for top-level extraction + // Recursive calls from within ExtractConstructorParametersWithInfo do not receive semanticModel + // to avoid cross-compilation-unit issues and stack overflow + (constructorParams, hasInjectConstructor) = typeSymbol.ExtractConstructorParametersWithInfo(visited, semanticModel); + } + + // Extract injection members if requested (for decorators) + ImmutableEquatableArray? injectionMembers = null; + if(extractInjectionMembers) + { + injectionMembers = typeSymbol.ExtractInjectionMembersForDecorator(semanticModel); + if(injectionMembers.Length == 0) + { + injectionMembers = null; + } + } + + // Extract hierarchy (interfaces and base classes) if requested + ImmutableEquatableArray? allInterfaces = null; + ImmutableEquatableArray? allBaseClasses = null; + if(extractHierarchy) + { + allInterfaces = typeSymbol.GetAllInterfaces(); + allBaseClasses = typeSymbol.GetAllBaseClasses(); + } + + // Error types (e.g., types from other source generators not yet resolved) + // should be treated as simple types, not open generics + if(typeSymbol.TypeKind == TypeKind.Error) + { + return TypeData.CreateSimple( + typeName, + constructorParams, + hasInjectConstructor, + injectionMembers, + allInterfaces, + allBaseClasses); + } + + // Check if this is a wrapper type (collection or non-collection) for DI + var nameWithoutGeneric = GetNameWithoutGeneric(typeName); + var wrapperKind = typeSymbol.TryGetWrapperInfo(out var wrapperInfo) + ? NormalizeGeneratorWrapperKind(wrapperInfo.Kind) + : WrapperKind.None; + + if(wrapperKind is not WrapperKind.None && typeSymbol.TypeArguments.Length > 0) + { + foreach(var typeArgument in typeSymbol.TypeArguments.OfType()) + { + if(!typeArgument.TryGetWrapperInfo(out var childWrapperInfo)) + { + continue; + } + + var childWrapperKind = NormalizeGeneratorWrapperKind(childWrapperInfo.Kind); + if(childWrapperKind is WrapperKind.None) + { + continue; + } + + // Keep generator behavior for collection wrappers: do not downgrade + // IEnumerable> / IEnumerable> at this classification stage. + // Collection nesting limitations are enforced by wrapper resolver generation. + if(IsUnsupportedWrapperNesting(wrapperKind, childWrapperKind, isAfterCollection: false)) + { + wrapperKind = WrapperKind.None; + break; + } + } + } + + if(wrapperKind is not WrapperKind.None) + { + return TypeData.CreateWrapper( + typeName, + nameWithoutGeneric, + typeSymbol.ContainsGenericParameters, + typeSymbol.Arity, + wrapperKind, + typeSymbol.IsNestedOpenGeneric, + typeParameters, + constructorParams, + hasInjectConstructor, + injectionMembers, + allInterfaces, + allBaseClasses); + } + + if(typeSymbol.ContainsGenericParameters || typeSymbol.Arity > 0 || typeParameters is { Length: > 0 }) + { + return TypeData.CreateGeneric( + typeName, + nameWithoutGeneric, + typeSymbol.ContainsGenericParameters, + typeSymbol.Arity, + typeSymbol.IsNestedOpenGeneric, + typeParameters, + constructorParams, + hasInjectConstructor, + injectionMembers, + allInterfaces, + allBaseClasses); + } + + return TypeData.CreateSimple( + typeName, + constructorParams, + hasInjectConstructor, + injectionMembers, + allInterfaces, + allBaseClasses); + } + + /// + /// Builds the fully qualified type name for this type symbol. + /// For unbound generic types, uses actual type parameter names instead of empty placeholders. + /// + public string BuildTypeName() + { + // For unbound generic types (e.g., typeof(Handler<,>)), we need to get the + // type parameter names from TypeParameters, not from FullyQualifiedName + // FullyQualifiedName returns "global::Ns.Handler<,>" but we need "global::Ns.Handler" + if(typeSymbol.IsUnboundGenericType && typeSymbol.TypeParametersSource.Length > 0) + { + var nameWithoutGeneric = GetNameWithoutGeneric(typeSymbol.FullyQualifiedName); + var typeParamNames = typeSymbol.TypeParametersSource.Select(tp => tp.Name); + return $"{nameWithoutGeneric}<{string.Join(", ", typeParamNames)}>"; + } + + return typeSymbol.FullyQualifiedName; + } + + /// + /// Core implementation for extracting type parameters. + /// + /// Whether to extract constraint types for each type parameter. + /// Current recursion depth to prevent infinite recursion. + /// An immutable array of type parameters with their resolved types. + public ImmutableEquatableArray ExtractTypeParameters(bool extractConstraints, int depth) + { + const int MaxDepth = 10; // Prevent infinite recursion for pathological cases + + var typeParams = typeSymbol.TypeParametersSource; + if(typeParams.Length == 0 || depth >= MaxDepth) + { + return []; + } + + var typeArgs = typeSymbol.TypeArguments; + List parameters = new(typeParams.Length); + + for(int i = 0; i < typeParams.Length; i++) + { + var typeParam = typeParams[i]; + var typeArg = i < typeArgs.Length ? typeArgs[i] : null; + + // Create TypeData for the type argument + var (typeData, allInterfaces) = typeParam.CreateTypeDataForTypeArg(typeArg, depth); + + // Add interfaces if extracted + if(allInterfaces is { Length: > 0 }) + { + typeData = typeData with { AllInterfaces = allInterfaces }; + } + + // Extract constraints only when requested (to avoid recursion in basic scenarios) + ImmutableEquatableArray? constraintTypes = null; + if(extractConstraints) + { + constraintTypes = typeParam.ConstraintTypes + .Select(ct => ct is INamedTypeSymbol namedCt + ? namedCt.CreateBasicTypeData(depth + 1) + : ct.TypeKind == TypeKind.TypeParameter + ? TypeData.CreateTypeParameter(ct.FullyQualifiedName) + : TypeData.CreateGeneric( + ct.FullyQualifiedName, + GetNameWithoutGeneric(ct.FullyQualifiedName), + ct.ContainsGenericParameters, + 0, + false)) + .ToImmutableEquatableArray(); + } + + parameters.Add(new TypeParameter( + typeParam.Name, + typeData, + constraintTypes, + typeParam.HasValueTypeConstraint, + typeParam.HasReferenceTypeConstraint, + typeParam.HasUnmanagedTypeConstraint, + typeParam.HasNotNullConstraint, + typeParam.HasConstructorConstraint)); + } + + return parameters.ToImmutableEquatableArray(); + } + + /// + /// Gets all interfaces implemented by a type. + /// Creates basic TypeData without recursive type parameter extraction to avoid circular dependencies. + /// + public ImmutableEquatableArray GetAllInterfaces() => + typeSymbol.AllInterfaces.Select(CreateBasicTypeData).ToImmutableEquatableArray(); + + /// + /// Gets all base classes of a type, excluding System.Object. + /// Creates basic TypeData without recursive type parameter extraction to avoid circular dependencies. + /// + public ImmutableEquatableArray GetAllBaseClasses() + { + List result = []; + var baseType = typeSymbol.BaseType; + while(baseType != null && baseType.SpecialType != SpecialType.System_Object) + { + result.Add(baseType.CreateBasicTypeData()); + baseType = baseType.BaseType; + } + + return result.ToImmutableEquatableArray(); + } + + /// + /// Creates a basic TypeData with type parameters extracted recursively. + /// Does not extract constraint types to avoid circular dependencies. + /// + public TypeData CreateBasicTypeData(int depth = 0) + { + var typeName = typeSymbol.FullyQualifiedName; + var nameWithoutGeneric = GetNameWithoutGeneric(typeName); + + // Extract type parameters without constraints (to avoid recursion) + ImmutableEquatableArray? typeParameters = null; + if(typeSymbol.IsGenericType && typeSymbol.TypeArguments.Length > 0) + { + typeParameters = typeSymbol.ExtractTypeParameters(extractConstraints: false, depth); + } + + // Check if this is a wrapper type (collection or non-collection) + var wrapperKind = typeSymbol.TryGetWrapperInfo(out var wrapperInfo) + ? NormalizeGeneratorWrapperKind(wrapperInfo.Kind) + : WrapperKind.None; + + if(wrapperKind is not WrapperKind.None) + { + return TypeData.CreateWrapper( + typeName, + nameWithoutGeneric, + typeSymbol.ContainsGenericParameters, + typeSymbol.Arity, + wrapperKind, + typeSymbol.IsNestedOpenGeneric, + typeParameters); + } + + if(typeSymbol.ContainsGenericParameters || typeSymbol.Arity > 0 || typeParameters is { Length: > 0 }) + { + return TypeData.CreateGeneric( + typeName, + nameWithoutGeneric, + typeSymbol.ContainsGenericParameters, + typeSymbol.Arity, + typeSymbol.IsNestedOpenGeneric, + typeParameters); + } + + return TypeData.CreateSimple(typeName); + } + } + + extension(ITypeParameterSymbol typeParam) + { + /// + /// Creates TypeData for a type argument, with optional interface extraction. + /// + public (TypeData TypeData, ImmutableEquatableArray? AllInterfaces) CreateTypeDataForTypeArg( + ITypeSymbol? typeArg, + int depth) + { + if(typeArg is INamedTypeSymbol namedArg && typeArg.TypeKind != TypeKind.TypeParameter) + { + // For concrete types, recursively extract type parameters and interfaces + var typeData = namedArg.CreateBasicTypeData(depth + 1); + var allInterfaces = namedArg.AllInterfaces.Length > 0 + ? namedArg.AllInterfaces.Select(CreateInterfaceTypeData).ToImmutableEquatableArray() + : null; + return (typeData, allInterfaces); + } + + if(typeArg is not null) + { + var argName = typeArg.FullyQualifiedName; + TypeData typeData = typeArg.TypeKind == TypeKind.TypeParameter + ? TypeData.CreateTypeParameter(argName) + : TypeData.CreateGeneric( + argName, + GetNameWithoutGeneric(argName), + typeArg.ContainsGenericParameters, + 0); + return (typeData, null); + } + + // No type argument available, this is a type parameter placeholder + return (TypeData.CreateTypeParameter(typeParam.Name), null); + + // Creates a simple TypeData for an interface type. + static TypeData CreateInterfaceTypeData(INamedTypeSymbol iface) => + iface.IsGenericType || iface.Arity > 0 + ? TypeData.CreateGeneric( + iface.FullyQualifiedName, + GetNameWithoutGeneric(iface.FullyQualifiedName), + iface.ContainsGenericParameters, + iface.Arity, + false) + : TypeData.CreateSimple(iface.FullyQualifiedName); + } + } + + extension(IArrayTypeSymbol arrayTypeSymbol) + { + public TypeData GetTypeData( + bool extractConstructorParams = false, + bool extractHierarchy = false, + HashSet? visited = null) + { + var elementType = arrayTypeSymbol.ElementType; + var typeName = arrayTypeSymbol.FullyQualifiedName; + + // For arrays, create TypeData with element type as a pseudo-TypeParameter + // This allows TryGetArrayElementType to extract the element type + ImmutableEquatableArray typeParameters; + if(elementType is INamedTypeSymbol namedElementType) + { + var elementTypeData = namedElementType.GetTypeData(extractConstructorParams, extractHierarchy, visited); + typeParameters = [new TypeParameter("T", elementTypeData)]; + } + else + { + var elementTypeName = elementType.FullyQualifiedName; + TypeData elementTypeData = elementType.TypeKind == TypeKind.TypeParameter + ? TypeData.CreateTypeParameter(elementTypeName) + : TypeData.CreateGeneric( + elementTypeName, + GetNameWithoutGeneric(elementTypeName), + elementType.ContainsGenericParameters, + 0); + typeParameters = [new TypeParameter("T", elementTypeData)]; + } + + return TypeData.CreateWrapper( + typeName, + typeName, // For arrays, use full name as NameWithoutGeneric + elementType.ContainsGenericParameters, + GenericArity: 1, // Arrays have one "type parameter" (the element type) + WrapperKind.Array, // Arrays are collections + IsNestedOpenGeneric: false, + TypeParameters: typeParameters); + } + } +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformImportModule.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformImportModule.cs similarity index 93% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformImportModule.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformImportModule.cs index 040649e..69d6b5c 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformImportModule.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformImportModule.cs @@ -12,11 +12,8 @@ private static IEnumerable TransformImportModule(GeneratorAt { ct.ThrowIfCancellationRequested(); - // Get the ModuleType from the attribute - if(attr.ConstructorArguments.Length == 0) - continue; - - if(attr.ConstructorArguments[0].Value is not INamedTypeSymbol moduleType) + var moduleType = attr.GetImportedModuleType(); + if(moduleType is null) continue; // Get the assembly containing the module type @@ -44,11 +41,8 @@ private static IEnumerable TransformImportModuleGeneric(Gene { ct.ThrowIfCancellationRequested(); - var attrClass = attr.AttributeClass; - if(attrClass?.IsGenericType != true || attrClass.TypeArguments.Length == 0) - continue; - - if(attrClass.TypeArguments[0] is not INamedTypeSymbol moduleType) + var moduleType = attr.GetImportedModuleType(); + if(moduleType is null) continue; // Get the assembly containing the module type @@ -105,11 +99,7 @@ private static IEnumerable ExtractOpenGenericEntriesFromModule if(fullName == Constants.IocImportModuleAttributeFullName) { - if(attr.ConstructorArguments.Length > 0 && - attr.ConstructorArguments[0].Value is INamedTypeSymbol modType) - { - importedModuleType = modType; - } + importedModuleType = attr.GetImportedModuleType(); } else if(attrClass.IsGenericType) { @@ -120,12 +110,8 @@ private static IEnumerable ExtractOpenGenericEntriesFromModule ? metadataName : $"{metadataNamespace}.{metadataName}"; - if(originalFullName == Constants.IocImportModuleAttributeFullName_T1 && - attrClass.TypeArguments.Length > 0 && - attrClass.TypeArguments[0] is INamedTypeSymbol genericModType) - { - importedModuleType = genericModType; - } + if(originalFullName == Constants.IocImportModuleAttributeFullName_T1) + importedModuleType = attr.GetImportedModuleType(); } if(importedModuleType is not null) diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformRegister.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformRegister.Extraction.cs similarity index 89% rename from src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformRegister.cs rename to src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformRegister.Extraction.cs index 1b4e5f2..f0e98f3 100644 --- a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Generator/TransformRegister.cs +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformRegister.Extraction.cs @@ -1,73 +1,7 @@ -ο»Ώnamespace SourceGen.Ioc; +namespace SourceGen.Ioc; partial class IocSourceGenerator { - private static RegistrationData? TransformRegister(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) - { - if(ctx.TargetSymbol is not INamedTypeSymbol typeSymbol) - return null; - - var attributeData = ctx.Attributes.FirstOrDefault(); - if(attributeData == null) - return null; - - return ExtractRegistrationData(typeSymbol, attributeData, ctx.SemanticModel); - } - - /// - /// Transforms generic IocRegisterAttribute (e.g., IocRegisterAttribute<T>) to extract registration data. - /// The service types are specified via type parameters instead of constructor arguments. - /// - private static RegistrationData? TransformRegisterGeneric(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) - { - if(ctx.TargetSymbol is not INamedTypeSymbol typeSymbol) - return null; - - var attributeData = ctx.Attributes.FirstOrDefault(); - if(attributeData == null) - return null; - - return ExtractRegistrationDataFromGenericAttribute(typeSymbol, attributeData, ctx.SemanticModel); - } - - private static IEnumerable TransformRegisterFor(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) - { - foreach(var attr in ctx.Attributes) - { - if(attr.ConstructorArguments.Length == 0) - continue; - if(attr.ConstructorArguments[0].Value is not INamedTypeSymbol targetType) - continue; - - var data = ExtractRegistrationData(targetType, attr, ctx.SemanticModel); - - yield return data; - } - } - - /// - /// Transforms generic IoCRegisterForAttribute (IoCRegisterForAttribute<T>) to extract registration data. - /// The target type is specified via type parameter instead of constructor argument. - /// - private static IEnumerable TransformRegisterForGeneric(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) - { - foreach(var attr in ctx.Attributes) - { - var attrClass = attr.AttributeClass; - if(attrClass?.IsGenericType != true || attrClass.TypeArguments.Length == 0) - continue; - - if(attrClass.TypeArguments[0] is not INamedTypeSymbol targetType) - continue; - - // Use ExtractRegistrationData because IoCRegisterForAttribute uses ServiceTypes named argument, - // not the generic type parameter, for specifying service types. - var data = ExtractRegistrationData(targetType, attr, ctx.SemanticModel); - - yield return data; - } - } - private static RegistrationData ExtractRegistrationData(INamedTypeSymbol typeSymbol, AttributeData attributeData, SemanticModel? semanticModel = null) { // Pass semanticModel to GetTypeData for proper nameof() expression resolution in constructor parameter keys @@ -645,5 +579,4 @@ private static bool IsNameofInvocation(ExpressionSyntax expression) IMethodSymbol method => CreateMethodInjection(method, key, semanticModel), _ => null }; - -} +} \ No newline at end of file diff --git a/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformRegister.cs b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformRegister.cs new file mode 100644 index 0000000..fd81eb4 --- /dev/null +++ b/src/Ioc/src/SourceGen.Ioc.SourceGenerator/Transforms/TransformRegister.cs @@ -0,0 +1,76 @@ +ο»Ώnamespace SourceGen.Ioc; + +partial class IocSourceGenerator +{ + private static RegistrationData? TransformRegister(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) + { + ct.ThrowIfCancellationRequested(); + + if(ctx.TargetSymbol is not INamedTypeSymbol typeSymbol) + return null; + + if(ctx.Attributes.Length == 0) + return null; + + return ExtractRegistrationData(typeSymbol, ctx.Attributes[0], ctx.SemanticModel); + } + + /// + /// Transforms generic IocRegisterAttribute (e.g., IocRegisterAttribute<T>) to extract registration data. + /// The service types are specified via type parameters instead of constructor arguments. + /// + private static RegistrationData? TransformRegisterGeneric(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) + { + ct.ThrowIfCancellationRequested(); + + if(ctx.TargetSymbol is not INamedTypeSymbol typeSymbol) + return null; + + if(ctx.Attributes.Length == 0) + return null; + + return ExtractRegistrationDataFromGenericAttribute(typeSymbol, ctx.Attributes[0], ctx.SemanticModel); + } + + private static IEnumerable TransformRegisterFor(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) + { + foreach(var attr in ctx.Attributes) + { + ct.ThrowIfCancellationRequested(); + + if(attr.ConstructorArguments.Length == 0) + continue; + if(attr.ConstructorArguments[0].Value is not INamedTypeSymbol targetType) + continue; + + var data = ExtractRegistrationData(targetType, attr, ctx.SemanticModel); + + yield return data; + } + } + + /// + /// Transforms generic IoCRegisterForAttribute (IoCRegisterForAttribute<T>) to extract registration data. + /// The target type is specified via type parameter instead of constructor argument. + /// + private static IEnumerable TransformRegisterForGeneric(GeneratorAttributeSyntaxContext ctx, CancellationToken ct) + { + foreach(var attr in ctx.Attributes) + { + ct.ThrowIfCancellationRequested(); + + var attrClass = attr.AttributeClass; + if(attrClass?.IsGenericType != true || attrClass.TypeArguments.Length == 0) + continue; + + if(attrClass.TypeArguments[0] is not INamedTypeSymbol targetType) + continue; + + // Use ExtractRegistrationData because IoCRegisterForAttribute uses ServiceTypes named argument, + // not the generic type parameter, for specifying service types. + var data = ExtractRegistrationData(targetType, attr, ctx.SemanticModel); + + yield return data; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC021Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC021Tests.cs index 70abe07..c9c360f 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC021Tests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC021Tests.cs @@ -518,6 +518,74 @@ public partial class TestContainer await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); } + [Test] + public async Task SGIOC021_PartialAccessor_EnumerableLazy_NoDiagnostic() + { + const string source = """ + using System; + using System.Collections.Generic; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IEnumerable> GetServices(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC021_PartialAccessor_EnumerableFunc_NoDiagnostic() + { + const string source = """ + using System; + using System.Collections.Generic; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IService { } + + [IocRegister(ServiceTypes = [typeof(IService)])] + public class TestService : IService { } + + [IocContainer(IntegrateServiceProvider = false)] + public partial class TestContainer + { + public partial IEnumerable> GetServices(); + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC021")).Count().IsEqualTo(0); + } + [Test] public async Task SGIOC021_AlwaysResolvableType_NoDiagnostic() { diff --git a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC030Tests.cs b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC030Tests.cs index 768a2bf..3851a1f 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC030Tests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/Analyzer/SGIOC030Tests.cs @@ -85,6 +85,134 @@ public Consumer(Task service) { } await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030")).Count().IsEqualTo(0); } + [Test] + public async Task SGIOC030_ConstructorServiceTypes_FallbackRecognized_NoDiagnosticWhenSyncRegistrationExists() + { + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IMyService)])] + public class AsyncService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceLifetime.Singleton, typeof(IMyService))] + public class SyncService : IMyService { } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + public Consumer(IMyService service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + + await Assert.That(SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030")).Count().IsEqualTo(0); + } + + [Test] + public async Task SGIOC030_ServiceTypesNamedArgument_PrecedesConstructorParams() + { + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IConstructorService { } + public interface INamedService { } + public interface IConsumer { } + + [IocRegister(ServiceTypes = [typeof(IConstructorService)])] + public class AsyncService : IConstructorService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceLifetime.Singleton, typeof(IConstructorService), ServiceTypes = [typeof(INamedService)])] + public class SyncService : IConstructorService, INamedService { } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + public Consumer(IConstructorService service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030").ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("service").And.Contains("IConstructorService"); + } + + [Test] + public async Task SGIOC030_GenericServiceType_IsRecognized() + { + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyService { } + public interface IConsumer { } + + [IocRegister(ServiceLifetime.Singleton)] + public class AsyncService : IMyService + { + [IocInject] + public Task InitializeAsync() => Task.CompletedTask; + } + + [IocRegister(ServiceTypes = [typeof(IConsumer)])] + public class Consumer : IConsumer + { + public Consumer(IMyService service) { } + } + """; + + var analyzerConfigOptions = new Dictionary + { + ["build_property.SourceGenIocFeatures"] = "Register,Container,MethodInject,AsyncMethodInject" + }; + + var diagnostics = await SourceGeneratorTestHelper.RunAnalyzerAsync( + source, + analyzerConfigOptions: analyzerConfigOptions); + var sgioc030 = SourceGeneratorTestHelper.GetDiagnosticsById(diagnostics, "SGIOC030").ToList(); + + await Assert.That(sgioc030).Count().IsEqualTo(1); + await Assert.That(sgioc030[0].GetMessage()).Contains("service").And.Contains("IMyService"); + } + [Test] public async Task SGIOC030_PropertyInjectionRequestsSyncTypeForAsyncInitService_ReportsDiagnostic() { diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ActivatorContainerTests.Container_WithRegisteredComponent_GeneratesComponentWithPropertyInjection.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ActivatorContainerTests.Container_WithRegisteredComponent_GeneratesComponentWithPropertyInjection.verified.txt index 42c6ebc..0c51564 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ActivatorContainerTests.Container_WithRegisteredComponent_GeneratesComponentWithPropertyInjection.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ActivatorContainerTests.Container_WithRegisteredComponent_GeneratesComponentWithPropertyInjection.verified.txt @@ -37,7 +37,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_DataService = GetTestNamespace_DataService(); + GetTestNamespace_DataService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_ScopedAsyncInit_EagerNone_DoesNotFireInScopeConstructor.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_ScopedAsyncInit_EagerNone_DoesNotFireInScopeConstructor.verified.txt new file mode 100644 index 0000000..35e5808 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_ScopedAsyncInit_EagerNone_DoesNotFireInScopeConstructor.verified.txt @@ -0,0 +1,326 @@ +ο»Ώ// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + } + + #endregion + + #region Service Resolution + + private global::System.Threading.Tasks.Task? _testNamespace_AsyncScopedService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_AsyncScopedServiceAsync() + { + if(_testNamespace_AsyncScopedService is not null) + return await _testNamespace_AsyncScopedService; + + _testNamespace_AsyncScopedService = CreateTestNamespace_AsyncScopedServiceAsync(); + return await _testNamespace_AsyncScopedService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_AsyncScopedServiceAsync() + { + var instance = new global::TestNamespace.AsyncScopedService(); + await instance.InitAsync(); + return instance; + } + + private global::TestNamespace.SyncScopedService? _testNamespace_SyncScopedService; + private global::TestNamespace.SyncScopedService GetTestNamespace_SyncScopedService() + { + if(_testNamespace_SyncScopedService is not null) return _testNamespace_SyncScopedService; + + var instance = new global::TestNamespace.SyncScopedService(); + + _testNamespace_SyncScopedService = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.AsyncScopedService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncScopedServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IAsyncScopedService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncScopedServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.SyncScopedService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_SyncScopedService()), + new(new ServiceIdentifier(typeof(global::TestNamespace.ISyncScopedService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_SyncScopedService()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + DisposeService(_testNamespace_SyncScopedService); + DisposeService(_testNamespace_AsyncScopedService); + return; + } + + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + await DisposeServiceAsync(_testNamespace_SyncScopedService); + await DisposeServiceAsync(_testNamespace_AsyncScopedService); + return; + } + + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_ScopedAsyncInit_EagerScoped_FiresInScopeConstructor.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_ScopedAsyncInit_EagerScoped_FiresInScopeConstructor.verified.txt new file mode 100644 index 0000000..16af250 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_ScopedAsyncInit_EagerScoped_FiresInScopeConstructor.verified.txt @@ -0,0 +1,330 @@ +ο»Ώ// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + + // Initialize eager scoped services + _ = GetTestNamespace_AsyncScopedServiceAsync(); + GetTestNamespace_SyncScopedService(); + } + + #endregion + + #region Service Resolution + + private global::System.Threading.Tasks.Task? _testNamespace_AsyncScopedService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_AsyncScopedServiceAsync() + { + if(_testNamespace_AsyncScopedService is not null) + return await _testNamespace_AsyncScopedService; + + _testNamespace_AsyncScopedService = CreateTestNamespace_AsyncScopedServiceAsync(); + return await _testNamespace_AsyncScopedService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_AsyncScopedServiceAsync() + { + var instance = new global::TestNamespace.AsyncScopedService(); + await instance.InitAsync(); + return instance; + } + + private global::TestNamespace.SyncScopedService _testNamespace_SyncScopedService = null!; + private global::TestNamespace.SyncScopedService GetTestNamespace_SyncScopedService() + { + if(_testNamespace_SyncScopedService is not null) return _testNamespace_SyncScopedService; + + var instance = new global::TestNamespace.SyncScopedService(); + + _testNamespace_SyncScopedService = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.AsyncScopedService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncScopedServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IAsyncScopedService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncScopedServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.SyncScopedService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_SyncScopedService!), + new(new ServiceIdentifier(typeof(global::TestNamespace.ISyncScopedService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c._testNamespace_SyncScopedService!), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + DisposeService(_testNamespace_SyncScopedService); + DisposeService(_testNamespace_AsyncScopedService); + return; + } + + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + await DisposeServiceAsync(_testNamespace_SyncScopedService); + await DisposeServiceAsync(_testNamespace_AsyncScopedService); + return; + } + + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonAsyncInit_EagerNone_DoesNotFireInConstructor.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonAsyncInit_EagerNone_DoesNotFireInConstructor.verified.txt new file mode 100644 index 0000000..159b103 --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonAsyncInit_EagerNone_DoesNotFireInConstructor.verified.txt @@ -0,0 +1,328 @@ +ο»Ώ// +#nullable enable +#pragma warning disable SGIOCEXP001 + +using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using SourceGen.Ioc; + +namespace TestNamespace; + +partial class TestContainer : IIocContainer, IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, ISupportRequiredService, IServiceScopeFactory, IServiceScope, IDisposable, IAsyncDisposable, IServiceProviderFactory +{ + private readonly IServiceProvider? _fallbackProvider; + private readonly bool _isRootScope = true; + private int _disposed; + + #region Constructors + + /// + /// Creates a new standalone container without external service provider fallback. + /// + public TestContainer() : this((IServiceProvider?)null) { } + + /// + /// Creates a new container with optional fallback to external service provider. + /// + /// Optional external service provider for unknown dependencies. + public TestContainer(IServiceProvider? fallbackProvider) + { + _fallbackProvider = fallbackProvider; + } + + private TestContainer(TestContainer parent) + { + _fallbackProvider = parent._fallbackProvider; + _isRootScope = false; + _testNamespace_AsyncService = parent._testNamespace_AsyncService; + _testNamespace_SyncService = parent._testNamespace_SyncService; + } + + #endregion + + #region Service Resolution + + private global::System.Threading.Tasks.Task? _testNamespace_AsyncService; + + private async global::System.Threading.Tasks.Task GetTestNamespace_AsyncServiceAsync() + { + if(_testNamespace_AsyncService is not null) + return await _testNamespace_AsyncService; + + _testNamespace_AsyncService = CreateTestNamespace_AsyncServiceAsync(); + return await _testNamespace_AsyncService; + } + + private async global::System.Threading.Tasks.Task CreateTestNamespace_AsyncServiceAsync() + { + var instance = new global::TestNamespace.AsyncService(); + await instance.InitAsync(); + return instance; + } + + private global::TestNamespace.SyncService? _testNamespace_SyncService; + private global::TestNamespace.SyncService GetTestNamespace_SyncService() + { + if(_testNamespace_SyncService is not null) return _testNamespace_SyncService; + + var instance = new global::TestNamespace.SyncService(); + + _testNamespace_SyncService = instance; + return instance; + } + + #endregion + + #region IServiceProvider + + public object? GetService(Type serviceType) + { + ThrowIfDisposed(); + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver)) + return resolver(this); + + return _fallbackProvider?.GetService(serviceType); + } + + #endregion + + #region IKeyedServiceProvider + + public object? GetKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.TryGetValue(new ServiceIdentifier(serviceType, key), out var resolver)) + return resolver(this); + + return _fallbackProvider is IKeyedServiceProvider keyed ? keyed.GetKeyedService(serviceType, serviceKey) : null; + } + + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) + { + ThrowIfDisposed(); + return GetKeyedService(serviceType, serviceKey) ?? throw new InvalidOperationException($"No service for type '{serviceType}' with key '{serviceKey}' has been registered."); + } + + #endregion + + #region ISupportRequiredService + + public object GetRequiredService(Type serviceType) + { + ThrowIfDisposed(); + return GetService(serviceType) ?? throw new InvalidOperationException($"No service for type '{serviceType}' has been registered."); + } + + #endregion + + #region ServiceProvider Extensions + + public T? GetService() where T : class + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredService() where T : notnull + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetServices() + { + ThrowIfDisposed(); + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + public T? GetKeyedService(object? serviceKey) where T : class + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? resolver(this) as T + : null; + } + + public T GetRequiredKeyedService(object? serviceKey) where T : notnull + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(T), key), out var resolver) + ? (T)resolver(this) + : throw new InvalidOperationException($"No service for type '{typeof(T)}' with key '{serviceKey}' has been registered."); + } + + public System.Collections.Generic.IEnumerable GetKeyedServices(object? serviceKey) + { + ThrowIfDisposed(); + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + return _serviceResolvers.TryGetValue(new ServiceIdentifier(typeof(System.Collections.Generic.IEnumerable), key), out var resolver) + ? (System.Collections.Generic.IEnumerable)resolver(this) + : []; + } + + #endregion + + #region IServiceProviderIsService + + public bool IsService(Type serviceType) + { + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey))) return true; + + return _fallbackProvider is IServiceProviderIsService isService && isService.IsService(serviceType); + } + + public bool IsKeyedService(Type serviceType, object? serviceKey) + { + var key = serviceKey ?? global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey; + + if(_serviceResolvers.ContainsKey(new ServiceIdentifier(serviceType, key))) return true; + + return _fallbackProvider is IServiceProviderIsKeyedService isKeyed && isKeyed.IsKeyedService(serviceType, serviceKey); + } + + #endregion + + #region IServiceScopeFactory + + public IServiceScope CreateScope() + { + ThrowIfDisposed(); + return new TestContainer(this); + } + + public AsyncServiceScope CreateAsyncScope() => new(CreateScope()); + + IServiceProvider IServiceScope.ServiceProvider => this; + + #endregion + + #region IIocContainer + + public static IReadOnlyCollection>> Resolvers => _serviceResolvers; + + private static readonly KeyValuePair>[] _localResolvers = + [ + new(new ServiceIdentifier(typeof(IServiceProvider), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(IServiceScopeFactory), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.TestContainer), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c), + new(new ServiceIdentifier(typeof(global::TestNamespace.AsyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.IAsyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_AsyncServiceAsync()), + new(new ServiceIdentifier(typeof(global::TestNamespace.SyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_SyncService()), + new(new ServiceIdentifier(typeof(global::TestNamespace.ISyncService), global::Microsoft.Extensions.DependencyInjection.KeyedService.AnyKey), static c => c.GetTestNamespace_SyncService()), + ]; + + private static readonly global::System.Collections.Frozen.FrozenDictionary> _serviceResolvers = _localResolvers.ToFrozenDictionary(); + + #endregion + + #region IServiceProviderFactory + + /// + /// Creates a new container builder (returns the same IServiceCollection). + /// + public IServiceCollection CreateBuilder(IServiceCollection services) => services; + + /// + /// Creates the service provider from the built IServiceCollection. + /// + public IServiceProvider CreateServiceProvider(IServiceCollection containerBuilder) + { + var fallbackProvider = global::Microsoft.Extensions.DependencyInjection.ServiceCollectionContainerBuilderExtensions.BuildServiceProvider(containerBuilder); + return new TestContainer(fallbackProvider); + } + + #endregion + + #region Disposal + + public void Dispose() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + DisposeService(_testNamespace_SyncService); + DisposeService(_testNamespace_AsyncService); + } + + public async ValueTask DisposeAsync() + { + if(Interlocked.Exchange(ref _disposed, 1) != 0) return; + + if(!_isRootScope) + { + return; + } + + await DisposeServiceAsync(_testNamespace_SyncService); + await DisposeServiceAsync(_testNamespace_AsyncService); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(_disposed != 0, GetType()); + } + + private static async ValueTask DisposeServiceAsync(object? service) + { + if(service is IAsyncDisposable asyncDisposable) await asyncDisposable.DisposeAsync(); + else if(service is IDisposable disposable) disposable.Dispose(); + } + + private static void DisposeService(object? service) + { + if(service is IDisposable disposable) disposable.Dispose(); + } + + private static async ValueTask DisposeServiceAsync(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + await DisposeServiceAsync(await task); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + private static void DisposeService(Task? task) + { + if(task is { IsCompletedSuccessfully: true }) + { + try + { + DisposeService(task.ConfigureAwait(false).GetAwaiter().GetResult()); + } + catch(Exception ex) + { + global::SourceGen.Ioc.IocContainerGlobalOptions.OnDisposeException?.Invoke(ex); + } + } + } + + #endregion +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_AsyncService_ExcludedFromEagerInit.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonAsyncInit_EagerSingleton_FiresInConstructor.verified.txt similarity index 99% rename from src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_AsyncService_ExcludedFromEagerInit.verified.txt rename to src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonAsyncInit_EagerSingleton_FiresInConstructor.verified.txt index 22559b8..ff2c9b7 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_AsyncService_ExcludedFromEagerInit.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_SingletonAsyncInit_EagerSingleton_FiresInConstructor.verified.txt @@ -35,7 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_SyncService = GetTestNamespace_SyncService(); + _ = GetTestNamespace_AsyncServiceAsync(); + GetTestNamespace_SyncService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TransientWithAsyncInit_GeneratesAsyncCreateMethod.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TransientWithAsyncInit_GeneratesAsyncCreateMethod.verified.txt index 29bf117..81ae3ed 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TransientWithAsyncInit_GeneratesAsyncCreateMethod.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.AsyncMethodInject_TransientWithAsyncInit_GeneratesAsyncCreateMethod.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Dependency = GetTestNamespace_Dependency(); + GetTestNamespace_Dependency(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.cs b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.cs index 722a86c..cb4ed36 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/AsyncMethodInjectTests.cs @@ -252,14 +252,13 @@ public partial class TestContainer { } } // ───────────────────────────────────────────────────────────────────────── - // Eager resolve exclusion β€” async-init services must NOT be in constructor init + // Async-init eager init β€” fire-and-forget startup follows EagerResolveOptions // ───────────────────────────────────────────────────────────────────────── [Test] - public async Task AsyncMethodInject_AsyncService_ExcludedFromEagerInit() + public async Task AsyncMethodInject_SingletonAsyncInit_EagerSingleton_FiresInConstructor() { - // EagerResolveOptions.Singleton is set, but the async-init service must NOT be eager. - // The sync-only SyncService IS eager. Verify that only SyncService appears in the ctor. + // EagerResolveOptions.Singleton eagerly starts async-init singletons in the root ctor. const string source = """ using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -297,6 +296,129 @@ public partial class TestContainer { } await Verify(generatedSource); } + [Test] + public async Task AsyncMethodInject_SingletonAsyncInit_EagerNone_DoesNotFireInConstructor() + { + // EagerResolveOptions.None keeps async-init singletons lazy too. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IAsyncService { } + public interface ISyncService { } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IAsyncService)])] + public class AsyncService : IAsyncService + { + [IocInject] + public async Task InitAsync() { } + } + + [IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(ISyncService)])] + public class SyncService : ISyncService { } + + [IocContainer(EagerResolveOptions = EagerResolveOptions.None, ThreadSafeStrategy = ThreadSafeStrategy.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_ScopedAsyncInit_EagerScoped_FiresInScopeConstructor() + { + // EagerResolveOptions.Scoped eagerly starts async-init scoped services in child scopes. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IAsyncScopedService { } + public interface ISyncScopedService { } + + [IocRegister(Lifetime = ServiceLifetime.Scoped, ServiceTypes = [typeof(IAsyncScopedService)])] + public class AsyncScopedService : IAsyncScopedService + { + [IocInject] + public async Task InitAsync() { } + } + + [IocRegister(Lifetime = ServiceLifetime.Scoped, ServiceTypes = [typeof(ISyncScopedService)])] + public class SyncScopedService : ISyncScopedService { } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.Scoped)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + + [Test] + public async Task AsyncMethodInject_ScopedAsyncInit_EagerNone_DoesNotFireInScopeConstructor() + { + // EagerResolveOptions.None keeps async-init scoped services lazy. + const string source = """ + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IAsyncScopedService { } + public interface ISyncScopedService { } + + [IocRegister(Lifetime = ServiceLifetime.Scoped, ServiceTypes = [typeof(IAsyncScopedService)])] + public class AsyncScopedService : IAsyncScopedService + { + [IocInject] + public async Task InitAsync() { } + } + + [IocRegister(Lifetime = ServiceLifetime.Scoped, ServiceTypes = [typeof(ISyncScopedService)])] + public class SyncScopedService : ISyncScopedService { } + + [IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.None)] + public partial class TestContainer { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator( + source, + analyzerConfigOptions: new Dictionary + { + ["build_property.SourceGenIocFeatures"] = AsyncMethodInjectFeatures + }); + + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "Container.g.cs"); + + await Verify(generatedSource); + } + // ───────────────────────────────────────────────────────────────────────── // Collection exclusion β€” async-init services must NOT appear in collection resolvers // ───────────────────────────────────────────────────────────────────────── diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_InGlobalNamespace_GeneratesWithoutNamespace.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_InGlobalNamespace_GeneratesWithoutNamespace.verified.txt index fb3ffcf..301b41f 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_InGlobalNamespace_GeneratesWithoutNamespace.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_InGlobalNamespace_GeneratesWithoutNamespace.verified.txt @@ -33,7 +33,7 @@ partial class GlobalContainer : IIocContainer, IService _fallbackProvider = fallbackProvider; // Initialize eager singletons - _myService = GetMyService(); + GetMyService(); } private GlobalContainer(GlobalContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_InNestedNamespace_GeneratesCorrectNamespace.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_InNestedNamespace_GeneratesCorrectNamespace.verified.txt index c8cd2da..06ac869 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_InNestedNamespace_GeneratesCorrectNamespace.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_InNestedNamespace_GeneratesCorrectNamespace.verified.txt @@ -35,7 +35,7 @@ partial class DeepNamespaceContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Logger = GetTestNamespace_Logger(); - _testNamespace_Repository = GetTestNamespace_Repository(); + GetTestNamespace_Logger(); + GetTestNamespace_Repository(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_WithMultipleLifetimes_GeneratesCorrectContainer.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_WithMultipleLifetimes_GeneratesCorrectContainer.verified.txt index 82405d0..5279a94 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_WithMultipleLifetimes_GeneratesCorrectContainer.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.Container_WithMultipleLifetimes_GeneratesCorrectContainer.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_SingletonService = GetTestNamespace_SingletonService(); + GetTestNamespace_SingletonService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.SimpleContainer_GeneratesBasicContainer.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.SimpleContainer_GeneratesBasicContainer.verified.txt index 4ec4196..6f7e170 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.SimpleContainer_GeneratesBasicContainer.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/BasicContainerTests.SimpleContainer_GeneratesBasicContainer.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); + GetTestNamespace_MyService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/CollectionResolutionContainerTests.Container_WithCollectionResolution_GeneratesEnumerableService.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/CollectionResolutionContainerTests.Container_WithCollectionResolution_GeneratesEnumerableService.verified.txt index ef0362a..adb26f4 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/CollectionResolutionContainerTests.Container_WithCollectionResolution_GeneratesEnumerableService.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/CollectionResolutionContainerTests.Container_WithCollectionResolution_GeneratesEnumerableService.verified.txt @@ -35,9 +35,9 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Plugin1 = GetTestNamespace_Plugin1(); - _testNamespace_Plugin2 = GetTestNamespace_Plugin2(); - _testNamespace_Plugin3 = GetTestNamespace_Plugin3(); + GetTestNamespace_Plugin1(); + GetTestNamespace_Plugin2(); + GetTestNamespace_Plugin3(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/CollectionResolutionContainerTests.Container_WithCollectionResolution_GeneratesReadOnlyCollectionAndArrayServices.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/CollectionResolutionContainerTests.Container_WithCollectionResolution_GeneratesReadOnlyCollectionAndArrayServices.verified.txt index 4f76e1d..71c2cc2 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/CollectionResolutionContainerTests.Container_WithCollectionResolution_GeneratesReadOnlyCollectionAndArrayServices.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/CollectionResolutionContainerTests.Container_WithCollectionResolution_GeneratesReadOnlyCollectionAndArrayServices.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Handler1 = GetTestNamespace_Handler1(); - _testNamespace_Handler2 = GetTestNamespace_Handler2(); + GetTestNamespace_Handler1(); + GetTestNamespace_Handler2(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithExplicitOnlyAndIncludeTags_ExplicitOnlyTakesPrecedence.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithExplicitOnlyAndIncludeTags_ExplicitOnlyTakesPrecedence.verified.txt index c673018..4addbeb 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithExplicitOnlyAndIncludeTags_ExplicitOnlyTakesPrecedence.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithExplicitOnlyAndIncludeTags_ExplicitOnlyTakesPrecedence.verified.txt @@ -35,7 +35,7 @@ partial class ExplicitContainer : IIocContainer { // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); + GetTestNamespace_MyService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithMultipleInterfaceRegistration_GeneratesAllResolutions.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithMultipleInterfaceRegistration_GeneratesAllResolutions.verified.txt index e67835c..949e3df 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithMultipleInterfaceRegistration_GeneratesAllResolutions.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithMultipleInterfaceRegistration_GeneratesAllResolutions.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_ReadWriteService = GetTestNamespace_ReadWriteService(); + GetTestNamespace_ReadWriteService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithoutDIPackage_DoesNotGenerateIServiceProviderFactory.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithoutDIPackage_DoesNotGenerateIServiceProviderFactory.verified.txt index a21e3bd..27a1d74 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithoutDIPackage_DoesNotGenerateIServiceProviderFactory.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ContainerOptionsTests.Container_WithoutDIPackage_DoesNotGenerateIServiceProviderFactory.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); + GetTestNamespace_MyService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/DecoratorContainerTests.Container_WithDecoratorInjectionMembers_GeneratesPropertyAndMethodInjection.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/DecoratorContainerTests.Container_WithDecoratorInjectionMembers_GeneratesPropertyAndMethodInjection.verified.txt index c4f743a..33ed0cd 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/DecoratorContainerTests.Container_WithDecoratorInjectionMembers_GeneratesPropertyAndMethodInjection.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/DecoratorContainerTests.Container_WithDecoratorInjectionMembers_GeneratesPropertyAndMethodInjection.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MemoryCache = GetTestNamespace_MemoryCache(); - _testNamespace_Logger = GetTestNamespace_Logger(); + GetTestNamespace_MemoryCache(); + GetTestNamespace_Logger(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/DecoratorContainerTests.Container_WithDecorators_HavingClosedGenericDependency_GeneratesDirectResolverCalls.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/DecoratorContainerTests.Container_WithDecorators_HavingClosedGenericDependency_GeneratesDirectResolverCalls.verified.txt index 437fe67..65f8cdd 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/DecoratorContainerTests.Container_WithDecorators_HavingClosedGenericDependency_GeneratesDirectResolverCalls.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/DecoratorContainerTests.Container_WithDecorators_HavingClosedGenericDependency_GeneratesDirectResolverCalls.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_TestHandler = GetTestNamespace_TestHandler(); - _testNamespace_Logger_TestNamespace_HandlerDecorator_TestNamespace_TestRequest__System_Collections_Generic_List_string___ = GetTestNamespace_Logger_TestNamespace_HandlerDecorator_TestNamespace_TestRequest__System_Collections_Generic_List_string___(); + GetTestNamespace_TestHandler(); + GetTestNamespace_Logger_TestNamespace_HandlerDecorator_TestNamespace_TestRequest__System_Collections_Generic_List_string___(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithDefaultEagerResolveOptions_EagerSingletons.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithDefaultEagerResolveOptions_EagerSingletons.verified.txt index 4ec4196..6f7e170 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithDefaultEagerResolveOptions_EagerSingletons.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithDefaultEagerResolveOptions_EagerSingletons.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); + GetTestNamespace_MyService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerResolveOptionsScoped_EagerScoped.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerResolveOptionsScoped_EagerScoped.verified.txt index eea4edd..3c7c3d1 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerResolveOptionsScoped_EagerScoped.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerResolveOptionsScoped_EagerScoped.verified.txt @@ -42,7 +42,7 @@ partial class TestContainer : IIocContainer _testNamespace_SingletonService = parent._testNamespace_SingletonService; // Initialize eager scoped services - _testNamespace_ScopedService = GetTestNamespace_ScopedService(); + GetTestNamespace_ScopedService(); } #endregion diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerResolveOptionsSingletonAndScoped_EagerBoth.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerResolveOptionsSingletonAndScoped_EagerBoth.verified.txt index 109aded..5fe4c22 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerResolveOptionsSingletonAndScoped_EagerBoth.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerResolveOptionsSingletonAndScoped_EagerBoth.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_SingletonService = GetTestNamespace_SingletonService(); + GetTestNamespace_SingletonService(); } private TestContainer(TestContainer parent) @@ -45,7 +45,7 @@ partial class TestContainer : IIocContainer _testNamespace_SingletonService = parent._testNamespace_SingletonService; // Initialize eager scoped services - _testNamespace_ScopedService = GetTestNamespace_ScopedService(); + GetTestNamespace_ScopedService(); } #endregion diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerSingletonAndDependencies_ResolvesDependenciesFirst.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerSingletonAndDependencies_ResolvesDependenciesFirst.verified.txt index 2af1441..bfc786d 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerSingletonAndDependencies_ResolvesDependenciesFirst.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/EagerResolveOptionsTests.Container_WithEagerSingletonAndDependencies_ResolvesDependenciesFirst.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Dependency = GetTestNamespace_Dependency(); - _testNamespace_MyService = GetTestNamespace_MyService(); + GetTestNamespace_Dependency(); + GetTestNamespace_MyService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/FactoryAndInstanceContainerTests.Container_WithFactoryRegistration_UsesFactoryMethod.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/FactoryAndInstanceContainerTests.Container_WithFactoryRegistration_UsesFactoryMethod.verified.txt index ec7a30e..f7f2402 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/FactoryAndInstanceContainerTests.Container_WithFactoryRegistration_UsesFactoryMethod.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/FactoryAndInstanceContainerTests.Container_WithFactoryRegistration_UsesFactoryMethod.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Connection_TestNamespace_ConnectionFactory_Create = GetTestNamespace_Connection_TestNamespace_ConnectionFactory_Create(); + GetTestNamespace_Connection_TestNamespace_ConnectionFactory_Create(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_AlsoSpecifiedImpls.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_AlsoSpecifiedImpls.verified.txt index c4840bd..a28e3ee 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_AlsoSpecifiedImpls.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_AlsoSpecifiedImpls.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Handler_TestNamespace_Entity_ = GetTestNamespace_Handler_TestNamespace_Entity_(); - _testNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity2___TestNamespace_FactoryContainer_Create = GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity2___TestNamespace_FactoryContainer_Create(); + GetTestNamespace_Handler_TestNamespace_Entity_(); + GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity2___TestNamespace_FactoryContainer_Create(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_GeneratesCorrectFactoryCall.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_GeneratesCorrectFactoryCall.verified.txt index 9c3092b..4cb39d1 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_GeneratesCorrectFactoryCall.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_GeneratesCorrectFactoryCall.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___TestNamespace_FactoryContainer_Create = GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___TestNamespace_FactoryContainer_Create(); + GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___TestNamespace_FactoryContainer_Create(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_MultipleDiscoveries_GeneratesMultipleFactoryCalls.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_MultipleDiscoveries_GeneratesMultipleFactoryCalls.verified.txt index fe13ca8..4431d12 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_MultipleDiscoveries_GeneratesMultipleFactoryCalls.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_MultipleDiscoveries_GeneratesMultipleFactoryCalls.verified.txt @@ -35,9 +35,9 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___TestNamespace_FactoryContainer_Create = GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___TestNamespace_FactoryContainer_Create(); - _testNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_User___TestNamespace_FactoryContainer_Create = GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_User___TestNamespace_FactoryContainer_Create(); - _testNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Order___TestNamespace_FactoryContainer_Create = GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Order___TestNamespace_FactoryContainer_Create(); + GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___TestNamespace_FactoryContainer_Create(); + GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_User___TestNamespace_FactoryContainer_Create(); + GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Order___TestNamespace_FactoryContainer_Create(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_MultipleTypeParameters_GeneratesCorrectFactoryCall.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_MultipleTypeParameters_GeneratesCorrectFactoryCall.verified.txt index e60b1bc..40e22f4 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_MultipleTypeParameters_GeneratesCorrectFactoryCall.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_MultipleTypeParameters_GeneratesCorrectFactoryCall.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___System_Collections_Generic_List_TestNamespace_Dto___TestNamespace_FactoryContainer_Create = GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___System_Collections_Generic_List_TestNamespace_Dto___TestNamespace_FactoryContainer_Create(); + GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___System_Collections_Generic_List_TestNamespace_Dto___TestNamespace_FactoryContainer_Create(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_RegisterImplWin.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_RegisterImplWin.verified.txt index 4dc1462..31014e3 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_RegisterImplWin.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_RegisterImplWin.verified.txt @@ -35,10 +35,10 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Handler_TestNamespace_Entity_ = GetTestNamespace_Handler_TestNamespace_Entity_(); - _testNamespace_Handler_TestNamespace_Entity__TestNamespace_FactoryContainer_Create = GetTestNamespace_Handler_TestNamespace_Entity__TestNamespace_FactoryContainer_Create(); - _testNamespace_Handler_TestNamespace_Entity2_ = GetTestNamespace_Handler_TestNamespace_Entity2_(); - _testNamespace_Handler_TestNamespace_Entity2__TestNamespace_FactoryContainer_Create = GetTestNamespace_Handler_TestNamespace_Entity2__TestNamespace_FactoryContainer_Create(); + GetTestNamespace_Handler_TestNamespace_Entity_(); + GetTestNamespace_Handler_TestNamespace_Entity__TestNamespace_FactoryContainer_Create(); + GetTestNamespace_Handler_TestNamespace_Entity2_(); + GetTestNamespace_Handler_TestNamespace_Entity2__TestNamespace_FactoryContainer_Create(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_ReversedTypeParameterMapping_GeneratesCorrectOrder.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_ReversedTypeParameterMapping_GeneratesCorrectOrder.verified.txt index 2fe2345..e5997a6 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_ReversedTypeParameterMapping_GeneratesCorrectOrder.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_ReversedTypeParameterMapping_GeneratesCorrectOrder.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___System_Collections_Generic_List_TestNamespace_Dto___TestNamespace_FactoryContainer_Create = GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___System_Collections_Generic_List_TestNamespace_Dto___TestNamespace_FactoryContainer_Create(); + GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___System_Collections_Generic_List_TestNamespace_Dto___TestNamespace_FactoryContainer_Create(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_WithServiceProvider_GeneratesProviderParameter.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_WithServiceProvider_GeneratesProviderParameter.verified.txt index d1cafcd..38cc583 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_WithServiceProvider_GeneratesProviderParameter.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericFactoryContainerTests.Container_WithGenericFactory_WithServiceProvider_GeneratesProviderParameter.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___TestNamespace_FactoryContainer_Create = GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___TestNamespace_FactoryContainer_Create(); + GetTestNamespace_IRequestHandler_System_Threading_Tasks_Task_TestNamespace_Entity___TestNamespace_FactoryContainer_Create(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericServiceContainerTests.Container_WithOpenGenerics_NoIntegrateServiceProvider_FallbacksToProviderForGenericTypes.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericServiceContainerTests.Container_WithOpenGenerics_NoIntegrateServiceProvider_FallbacksToProviderForGenericTypes.verified.txt index 39f903e..1bde5ff 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericServiceContainerTests.Container_WithOpenGenerics_NoIntegrateServiceProvider_FallbacksToProviderForGenericTypes.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericServiceContainerTests.Container_WithOpenGenerics_NoIntegrateServiceProvider_FallbacksToProviderForGenericTypes.verified.txt @@ -27,8 +27,8 @@ partial class TestContainer : IIocContainer { // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Logger_TestNamespace_MyService_ = GetTestNamespace_Logger_TestNamespace_MyService_(); + GetTestNamespace_MyService(); + GetTestNamespace_Logger_TestNamespace_MyService_(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericServiceContainerTests.Container_WithOpenGenerics_UseSwitchStatement_FallbacksToProviderForGenericTypes.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericServiceContainerTests.Container_WithOpenGenerics_UseSwitchStatement_FallbacksToProviderForGenericTypes.verified.txt index 07306aa..66f72b8 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericServiceContainerTests.Container_WithOpenGenerics_UseSwitchStatement_FallbacksToProviderForGenericTypes.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/GenericServiceContainerTests.Container_WithOpenGenerics_UseSwitchStatement_FallbacksToProviderForGenericTypes.verified.txt @@ -27,7 +27,7 @@ partial class TestContainer : IIocContainer { // Initialize eager singletons - _testNamespace_UserService = GetTestNamespace_UserService(); + GetTestNamespace_UserService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_FieldInjectFeatureDisabled_IgnoresFieldInjection.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_FieldInjectFeatureDisabled_IgnoresFieldInjection.verified.txt index f563bff..32f109d 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_FieldInjectFeatureDisabled_IgnoresFieldInjection.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_FieldInjectFeatureDisabled_IgnoresFieldInjection.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Logger = GetTestNamespace_Logger(); + GetTestNamespace_Logger(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithFieldInjection_GeneratesFieldAssignment.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithFieldInjection_GeneratesFieldAssignment.verified.txt index f5d89eb..505d0c9 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithFieldInjection_GeneratesFieldAssignment.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithFieldInjection_GeneratesFieldAssignment.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Logger = GetTestNamespace_Logger(); + GetTestNamespace_Logger(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithMethodInjection_GeneratesMethodCall.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithMethodInjection_GeneratesMethodCall.verified.txt index fc66154..204e153 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithMethodInjection_GeneratesMethodCall.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithMethodInjection_GeneratesMethodCall.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Logger = GetTestNamespace_Logger(); + GetTestNamespace_Logger(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithPropertyInjection_GeneratesPropertyAssignment.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithPropertyInjection_GeneratesPropertyAssignment.verified.txt index b6b27b9..22aed7b 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithPropertyInjection_GeneratesPropertyAssignment.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/InjectionTests.Container_WithPropertyInjection_GeneratesPropertyAssignment.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Logger = GetTestNamespace_Logger(); + GetTestNamespace_Logger(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithKeyedServices_GeneratesKeyedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithKeyedServices_GeneratesKeyedResolution.verified.txt index 98e2c69..26fde10 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithKeyedServices_GeneratesKeyedResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithKeyedServices_GeneratesKeyedResolution.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_KeyedService1__key1_ = GetTestNamespace_KeyedService1__key1_(); - _testNamespace_KeyedService2__key2_ = GetTestNamespace_KeyedService2__key2_(); + GetTestNamespace_KeyedService1__key1_(); + GetTestNamespace_KeyedService2__key2_(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithKeyedServices_UseSwitchStatement.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithKeyedServices_UseSwitchStatement.verified.txt index 4294fc4..7c8f8d3 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithKeyedServices_UseSwitchStatement.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithKeyedServices_UseSwitchStatement.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_KeyedService1__key1_ = GetTestNamespace_KeyedService1__key1_(); - _testNamespace_KeyedService2__key2_ = GetTestNamespace_KeyedService2__key2_(); + GetTestNamespace_KeyedService1__key1_(); + GetTestNamespace_KeyedService2__key2_(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_Factory_InjectsKey.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_Factory_InjectsKey.verified.txt index 8dbfbe4..05b3c8a 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_Factory_InjectsKey.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_Factory_InjectsKey.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_KeyedService__factoryKey_ = GetTestNamespace_KeyedService__factoryKey_(); + GetTestNamespace_KeyedService__factoryKey_(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_Factory_WithDependencies_InjectsCorrectly.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_Factory_WithDependencies_InjectsCorrectly.verified.txt index 42921d2..bb85659 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_Factory_WithDependencies_InjectsCorrectly.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_Factory_WithDependencies_InjectsCorrectly.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Logger = GetTestNamespace_Logger(); + GetTestNamespace_Logger(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_InjectsRegistrationKey.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_InjectsRegistrationKey.verified.txt index 699e07d..4d4d7b7 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_InjectsRegistrationKey.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_InjectsRegistrationKey.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Dependency = GetTestNamespace_Dependency(); - _testNamespace_KeyedService__myKey_ = GetTestNamespace_KeyedService__myKey_(); + GetTestNamespace_Dependency(); + GetTestNamespace_KeyedService__myKey_(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_MethodInjection_InjectsKey.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_MethodInjection_InjectsKey.verified.txt index f81f602..1164736 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_MethodInjection_InjectsKey.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_MethodInjection_InjectsKey.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_PrimaryService_TestNamespace_ServiceKey_Primary = GetTestNamespace_PrimaryService_TestNamespace_ServiceKey_Primary(); - _testNamespace_SecondaryService_TestNamespace_ServiceKey_Secondary = GetTestNamespace_SecondaryService_TestNamespace_ServiceKey_Secondary(); + GetTestNamespace_PrimaryService_TestNamespace_ServiceKey_Primary(); + GetTestNamespace_SecondaryService_TestNamespace_ServiceKey_Secondary(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_MethodInjection_WithOtherParameters_InjectsCorrectly.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_MethodInjection_WithOtherParameters_InjectsCorrectly.verified.txt index 2e3e56c..e8135e7 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_MethodInjection_WithOtherParameters_InjectsCorrectly.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_MethodInjection_WithOtherParameters_InjectsCorrectly.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Logger = GetTestNamespace_Logger(); + GetTestNamespace_Logger(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_NullableKey_InjectsNullForNonKeyedService.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_NullableKey_InjectsNullForNonKeyedService.verified.txt index d7ec348..20b1174 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_NullableKey_InjectsNullForNonKeyedService.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/KeyedServiceTests.Container_WithServiceKeyAttribute_NullableKey_InjectsNullForNonKeyedService.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); + GetTestNamespace_MyService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithImportedModule_CombinesServices.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithImportedModule_CombinesServices.verified.txt index 390599d..c7338b8 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithImportedModule_CombinesServices.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithImportedModule_CombinesServices.verified.txt @@ -38,7 +38,7 @@ partial class AppContainer : IIocContainer, IServi _sharedLib_SharedModule = new global::SharedLib.SharedModule(fallbackProvider); // Initialize eager singletons - _mainApp_LocalService = GetMainApp_LocalService(); + GetMainApp_LocalService(); } private AppContainer(AppContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithImportedModule_UseSwitchStatementIgnored_UsesFrozenDictionary.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithImportedModule_UseSwitchStatementIgnored_UsesFrozenDictionary.verified.txt index 390599d..c7338b8 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithImportedModule_UseSwitchStatementIgnored_UsesFrozenDictionary.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithImportedModule_UseSwitchStatementIgnored_UsesFrozenDictionary.verified.txt @@ -38,7 +38,7 @@ partial class AppContainer : IIocContainer, IServi _sharedLib_SharedModule = new global::SharedLib.SharedModule(fallbackProvider); // Initialize eager singletons - _mainApp_LocalService = GetMainApp_LocalService(); + GetMainApp_LocalService(); } private AppContainer(AppContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithMultipleImportedModules_CombinesAllServices.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithMultipleImportedModules_CombinesAllServices.verified.txt index a7076e3..35e435e 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithMultipleImportedModules_CombinesAllServices.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/ModuleImportContainerTests.Container_WithMultipleImportedModules_CombinesAllServices.verified.txt @@ -40,7 +40,7 @@ partial class AppContainer : IIocContainer, IServi _sharedLib2_SharedModule2 = new global::SharedLib2.SharedModule2(fallbackProvider); // Initialize eager singletons - _mainApp_LocalService = GetMainApp_LocalService(); + GetMainApp_LocalService(); } private AppContainer(AppContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialAccessor_MixedMethodsAndProperties.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialAccessor_MixedMethodsAndProperties.verified.txt index 2bfe45e..4279138 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialAccessor_MixedMethodsAndProperties.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialAccessor_MixedMethodsAndProperties.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_ServiceA = GetTestNamespace_ServiceA(); + GetTestNamespace_ServiceA(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_NamingConflict_RenamesInternalResolver.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_NamingConflict_RenamesInternalResolver.verified.txt index ee13368..71c5274 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_NamingConflict_RenamesInternalResolver.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_NamingConflict_RenamesInternalResolver.verified.txt @@ -33,7 +33,7 @@ partial class TestContainer : IIocContainer, IServiceProv _fallbackProvider = fallbackProvider; // Initialize eager singletons - _myService = GetMyService_Resolve(); + GetMyService_Resolve(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_ResolvesRegisteredService.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_ResolvesRegisteredService.verified.txt index 46a571f..63e5cce 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_ResolvesRegisteredService.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_ResolvesRegisteredService.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); + GetTestNamespace_MyService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_WithKeyedService_ResolvesWithKey.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_WithKeyedService_ResolvesWithKey.verified.txt index 6688b06..f345960 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_WithKeyedService_ResolvesWithKey.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialMethod_WithKeyedService_ResolvesWithKey.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_RedisCache__redis_ = GetTestNamespace_RedisCache__redis_(); + GetTestNamespace_RedisCache__redis_(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialProperty_ResolvesRegisteredService.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialProperty_ResolvesRegisteredService.verified.txt index f803dc8..729fa19 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialProperty_ResolvesRegisteredService.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialProperty_ResolvesRegisteredService.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); + GetTestNamespace_MyService(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialProperty_WithKeyedService_ResolvesWithKey.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialProperty_WithKeyedService_ResolvesWithKey.verified.txt index d899609..f982178 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialProperty_WithKeyedService_ResolvesWithKey.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/PartialAccessorTests.PartialProperty_WithKeyedService_ResolvesWithKey.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MemoryCache__memory_ = GetTestNamespace_MemoryCache__memory_(); + GetTestNamespace_MemoryCache__memory_(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.CollectionWrapperKind_HasCorrectWrapperKindOnCollectionTypes.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.CollectionWrapperKind_HasCorrectWrapperKindOnCollectionTypes.verified.txt index b2673cb..8a1bc00 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.CollectionWrapperKind_HasCorrectWrapperKindOnCollectionTypes.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.CollectionWrapperKind_HasCorrectWrapperKindOnCollectionTypes.verified.txt @@ -35,9 +35,9 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_ServiceA = GetTestNamespace_ServiceA(); - _testNamespace_ServiceB = GetTestNamespace_ServiceB(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_ServiceA(); + GetTestNamespace_ServiceB(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.DictionaryDependency_GeneratesDictionaryResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.DictionaryDependency_GeneratesDictionaryResolution.verified.txt index 13c3de1..b7bfaa6 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.DictionaryDependency_GeneratesDictionaryResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.DictionaryDependency_GeneratesDictionaryResolution.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Handler1 = GetTestNamespace_Handler1(); - _testNamespace_HandlerRegistry = GetTestNamespace_HandlerRegistry(); + GetTestNamespace_Handler1(); + GetTestNamespace_HandlerRegistry(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.DictionaryDependency_WithKeyedServices_GeneratesKvpResolvers.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.DictionaryDependency_WithKeyedServices_GeneratesKvpResolvers.verified.txt index 0beda2a..5c875b4 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.DictionaryDependency_WithKeyedServices_GeneratesKvpResolvers.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.DictionaryDependency_WithKeyedServices_GeneratesKvpResolvers.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Handler1__h1_ = GetTestNamespace_Handler1__h1_(); - _testNamespace_HandlerRegistry = GetTestNamespace_HandlerRegistry(); + GetTestNamespace_Handler1__h1_(); + GetTestNamespace_HandlerRegistry(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt index 3868ac3..f810a79 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.EnumerableTask_FallsBackToServiceProvider.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_GeneratesFuncResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_GeneratesFuncResolution.verified.txt index 5529d84..4ea4654 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_GeneratesFuncResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_GeneratesFuncResolution.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_Consumer(); // Initialize Func wrapper fields _func_TestNamespace_IMyService_TestNamespace_MyService = new global::System.Func(() => GetTestNamespace_MyService()); diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenOptional_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenOptional_FallsBackToServiceProvider.verified.txt index 8823585..18ccdf2 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenOptional_FallsBackToServiceProvider.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithMultiParamFunc_InExplicitOnlyContainer_WhenOptional_FallsBackToServiceProvider.verified.txt @@ -35,7 +35,7 @@ partial class ExplicitContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithSingleInputParameter_MatchesConstructorParameterByType.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithSingleInputParameter_MatchesConstructorParameterByType.verified.txt index 55136b2..80ecd8a 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithSingleInputParameter_MatchesConstructorParameterByType.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithSingleInputParameter_MatchesConstructorParameterByType.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithUnmatchedInputType_IgnoresInputAndResolvesFromDi.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithUnmatchedInputType_IgnoresInputAndResolvesFromDi.verified.txt index d362375..08d78fb 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithUnmatchedInputType_IgnoresInputAndResolvesFromDi.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncDependency_WithUnmatchedInputType_IgnoresInputAndResolvesFromDi.verified.txt @@ -35,7 +35,7 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt index e10e813..3d2a7a2 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.FuncTask_FallsBackToServiceProvider.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_GeneratesKvpResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_GeneratesKvpResolution.verified.txt index fc9929a..ee7c5eb 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_GeneratesKvpResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_GeneratesKvpResolution.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Service1 = GetTestNamespace_Service1(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_Service1(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_WithKeyedServices_GeneratesKvpResolvers.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_WithKeyedServices_GeneratesKvpResolvers.verified.txt index 786d89c..9e5a887 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_WithKeyedServices_GeneratesKvpResolvers.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_WithKeyedServices_GeneratesKvpResolvers.verified.txt @@ -35,9 +35,9 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Handler1__handler1_ = GetTestNamespace_Handler1__handler1_(); - _testNamespace_Handler2__handler2_ = GetTestNamespace_Handler2__handler2_(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_Handler1__handler1_(); + GetTestNamespace_Handler2__handler2_(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_WithMixedKeyTypes_OnlyGeneratesMatchingKvpResolvers.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_WithMixedKeyTypes_OnlyGeneratesMatchingKvpResolvers.verified.txt index 510ed41..ea75dff 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_WithMixedKeyTypes_OnlyGeneratesMatchingKvpResolvers.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.KeyValuePairDependency_WithMixedKeyTypes_OnlyGeneratesMatchingKvpResolvers.verified.txt @@ -35,11 +35,11 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Keyed__Key_ = GetTestNamespace_Keyed__Key_(); - _testNamespace_KeyedEnum_TestNamespace_KeyEnum_Key0 = GetTestNamespace_KeyedEnum_TestNamespace_KeyEnum_Key0(); - _testNamespace_KeyedCsharp_TestNamespace_KeyedExtensions_Key = GetTestNamespace_KeyedCsharp_TestNamespace_KeyedExtensions_Key(); - _testNamespace_StringConsumer = GetTestNamespace_StringConsumer(); - _testNamespace_EnumConsumer = GetTestNamespace_EnumConsumer(); + GetTestNamespace_Keyed__Key_(); + GetTestNamespace_KeyedEnum_TestNamespace_KeyEnum_Key0(); + GetTestNamespace_KeyedCsharp_TestNamespace_KeyedExtensions_Key(); + GetTestNamespace_StringConsumer(); + GetTestNamespace_EnumConsumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyDependency_GeneratesLazyResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyDependency_GeneratesLazyResolution.verified.txt index 30c6034..b42821d 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyDependency_GeneratesLazyResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyDependency_GeneratesLazyResolution.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); // Initialize Lazy wrapper fields _lazy_TestNamespace_IMyService_TestNamespace_MyService = new global::System.Lazy(() => GetTestNamespace_MyService(), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication); diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt index 6566b5b..bdfa853 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.LazyTask_FallsBackToServiceProvider.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedEnumerableFunc_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedEnumerableFunc_GeneratesNestedResolution.verified.txt index 23e7d44..e7d2d42 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedEnumerableFunc_GeneratesNestedResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedEnumerableFunc_GeneratesNestedResolution.verified.txt @@ -35,9 +35,9 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Plugin1 = GetTestNamespace_Plugin1(); - _testNamespace_Plugin2 = GetTestNamespace_Plugin2(); - _testNamespace_PluginManager = GetTestNamespace_PluginManager(); + GetTestNamespace_Plugin1(); + GetTestNamespace_Plugin2(); + GetTestNamespace_PluginManager(); // Initialize Func wrapper fields _func_TestNamespace_IPlugin_TestNamespace_Plugin1 = new global::System.Func(() => GetTestNamespace_Plugin1()); diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedEnumerableLazy_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedEnumerableLazy_GeneratesNestedResolution.verified.txt index 915b4cc..b932260 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedEnumerableLazy_GeneratesNestedResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedEnumerableLazy_GeneratesNestedResolution.verified.txt @@ -35,9 +35,9 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Plugin1 = GetTestNamespace_Plugin1(); - _testNamespace_Plugin2 = GetTestNamespace_Plugin2(); - _testNamespace_PluginManager = GetTestNamespace_PluginManager(); + GetTestNamespace_Plugin1(); + GetTestNamespace_Plugin2(); + GetTestNamespace_PluginManager(); // Initialize Lazy wrapper fields _lazy_TestNamespace_IPlugin_TestNamespace_Plugin1 = new global::System.Lazy(() => GetTestNamespace_Plugin1(), global::System.Threading.LazyThreadSafetyMode.ExecutionAndPublication); diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedFuncLazy_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedFuncLazy_GeneratesNestedResolution.verified.txt index a7bf0d3..53966fc 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedFuncLazy_GeneratesNestedResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedFuncLazy_GeneratesNestedResolution.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedLazyEnumerable_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedLazyEnumerable_GeneratesNestedResolution.verified.txt index c5ca313..fc5cf63 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedLazyEnumerable_GeneratesNestedResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedLazyEnumerable_GeneratesNestedResolution.verified.txt @@ -35,9 +35,9 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Plugin1 = GetTestNamespace_Plugin1(); - _testNamespace_Plugin2 = GetTestNamespace_Plugin2(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_Plugin1(); + GetTestNamespace_Plugin2(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedLazyFunc_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedLazyFunc_GeneratesNestedResolution.verified.txt index 2b4c88a..f3b125b 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedLazyFunc_GeneratesNestedResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.NestedLazyFunc_GeneratesNestedResolution.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.ReadOnlyDictionaryDependency_GeneratesDictionaryResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.ReadOnlyDictionaryDependency_GeneratesDictionaryResolution.verified.txt index 13c3de1..b7bfaa6 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.ReadOnlyDictionaryDependency_GeneratesDictionaryResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.ReadOnlyDictionaryDependency_GeneratesDictionaryResolution.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_Handler1 = GetTestNamespace_Handler1(); - _testNamespace_HandlerRegistry = GetTestNamespace_HandlerRegistry(); + GetTestNamespace_Handler1(); + GetTestNamespace_HandlerRegistry(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt index ec13a60..8513f03 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskEnumerable_FallsBackToServiceProvider.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt index 52cf738..0da54e4 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskFunc_FallsBackToServiceProvider.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt index aeb281c..4d271ad 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TaskLazy_FallsBackToServiceProvider.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt index 517b59c..e47be7c 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_EnumerableLazyFunc_FallsBackToServiceProvider.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt index 5bd9767..1ddbe5e 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_FuncLazyEnumerable_GeneratesNestedResolution.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt index 6a10043..6fd8891 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt +++ b/src/Ioc/test/SourceGen.Ioc.Test/ContainerSourceGeneratorSnapshot/WrapperTypeDependencyTests.TripleNested_LazyFuncLazy_GeneratesNestedResolution.verified.txt @@ -35,8 +35,8 @@ partial class TestContainer : IIocContainer _fallbackProvider = fallbackProvider; // Initialize eager singletons - _testNamespace_MyService = GetTestNamespace_MyService(); - _testNamespace_Consumer = GetTestNamespace_Consumer(); + GetTestNamespace_MyService(); + GetTestNamespace_Consumer(); } private TestContainer(TestContainer parent) diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/TagsTests.Tags_SpecialCharacters_GeneratesEscapedTagLiterals.verified.txt b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/TagsTests.Tags_SpecialCharacters_GeneratesEscapedTagLiterals.verified.txt new file mode 100644 index 0000000..087f38b --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/TagsTests.Tags_SpecialCharacters_GeneratesEscapedTagLiterals.verified.txt @@ -0,0 +1,31 @@ +ο»Ώ// +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using System.Collections.Generic; +using System.Linq; + +namespace TestAssembly +{ + /// + /// Extension methods for registering services from TestAssembly. + /// + public static class TestAssemblyServiceCollectionExtensions + { + /// + /// Registers services. Services with tags are only registered when matching tags are passed. + /// + /// The service collection. + /// Optional tags to filter which services to register. + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddTestAssembly(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, params global::System.Collections.Generic.IEnumerable tags) + { + if (tags.Contains("a\"b") || tags.Contains("c\\d") || tags.Contains("line\nbreak")) + { + services.AddSingleton(); + services.AddSingleton((global::System.IServiceProvider sp) => sp.GetRequiredService()); + } + + return services; + } + } +} diff --git a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/TagsTests.cs b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/TagsTests.cs index 8186894..037de4b 100644 --- a/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/TagsTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.Test/RegisterSourceGeneratorSnapshot/TagsTests.cs @@ -101,4 +101,29 @@ public class BothTagsService : IBothTagsService { } await Verify(generatedSource); } + + [Test] + public async Task Tags_SpecialCharacters_GeneratesEscapedTagLiterals() + { + const string source = """ + using Microsoft.Extensions.DependencyInjection; + using SourceGen.Ioc; + + namespace TestNamespace; + + public interface IMyTaggedService { } + + [IocRegister( + Lifetime = ServiceLifetime.Singleton, + ServiceTypes = [typeof(IMyTaggedService)], + Tags = ["a\"b", "c\\d", "line\nbreak"])] + public class MyTaggedService : IMyTaggedService { } + """; + + var result = SourceGeneratorTestHelper.RunGenerator(source); + await result.VerifyCompilableAsync(); + var generatedSource = SourceGeneratorTestHelper.GetGeneratedSource(result, "ServiceRegistration"); + + await Verify(generatedSource); + } } diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/AsyncEagerResolveContainer.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/AsyncEagerResolveContainer.cs new file mode 100644 index 0000000..0a7baca --- /dev/null +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/TestCase/AsyncEagerResolveContainer.cs @@ -0,0 +1,44 @@ +namespace SourceGen.Ioc.TestAot.TestCase; + +public interface IAsyncEagerSingletonService +{ + bool IsInitialized { get; } +} + +public static class AsyncEagerSingletonProbe +{ + private static int constructedCount; + private static int initializeStartedCount; + + public static int ConstructedCount => global::System.Threading.Volatile.Read(ref constructedCount); + public static int InitializeStartedCount => global::System.Threading.Volatile.Read(ref initializeStartedCount); + + public static void Reset() + { + global::System.Threading.Interlocked.Exchange(ref constructedCount, 0); + global::System.Threading.Interlocked.Exchange(ref initializeStartedCount, 0); + } + + internal static void OnConstructed() => global::System.Threading.Interlocked.Increment(ref constructedCount); + + internal static void OnInitializeStarted() => global::System.Threading.Interlocked.Increment(ref initializeStartedCount); +} + +[IocRegister(Lifetime = ServiceLifetime.Singleton, ServiceTypes = [typeof(IAsyncEagerSingletonService)])] +public sealed class AsyncEagerSingletonService : IAsyncEagerSingletonService +{ + public AsyncEagerSingletonService() => AsyncEagerSingletonProbe.OnConstructed(); + + public bool IsInitialized { get; private set; } + + [IocInject] + public async Task InitializeAsync() + { + AsyncEagerSingletonProbe.OnInitializeStarted(); + await Task.CompletedTask; + IsInitialized = true; + } +} + +[IocContainer(ThreadSafeStrategy = ThreadSafeStrategy.None, EagerResolveOptions = EagerResolveOptions.Singleton)] +public sealed partial class AsyncEagerResolveContainer; \ No newline at end of file diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/AsyncInjectionTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/AsyncInjectionTests.cs index 9606721..951c263 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/AsyncInjectionTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/AsyncInjectionTests.cs @@ -5,6 +5,7 @@ namespace SourceGen.Ioc.TestAot.Tests; /// Verifies that services with [IocInject] async Task methods are properly /// initialized before first use in both standalone and MS.Extensions.DI scenarios. ///
+[NotInParallel] public sealed class AsyncInjectionTests { #region Standalone Container Tests (via partial Task accessor on AsyncInjectionModule) diff --git a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/EagerResolveTests.cs b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/EagerResolveTests.cs index a9f34e3..9156cf0 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/EagerResolveTests.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestAot/Tests/EagerResolveTests.cs @@ -5,8 +5,39 @@ namespace SourceGen.Ioc.TestAot.Tests; /// EagerResolveOptions.SingletonAndScoped resolve singletons during construction /// rather than on first use. ///
+[NotInParallel] public sealed class EagerResolveTests { + [Test] + public async Task AsyncEagerResolveContainer_SingletonAsyncInit_StartsDuringContainerConstruction() + { + AsyncEagerSingletonProbe.Reset(); + + await using var container = new AsyncEagerResolveContainer(); + + await Assert.That(AsyncEagerSingletonProbe.ConstructedCount).IsEqualTo(1); + await Assert.That(AsyncEagerSingletonProbe.InitializeStartedCount).IsEqualTo(1); + await Assert.That(AsyncEagerSingletonProbe.ConstructedCount).IsEqualTo(1); + await Assert.That(AsyncEagerSingletonProbe.InitializeStartedCount).IsEqualTo(1); + } + + [Test] + public async Task AsyncInjectionContainer_EagerNone_DoesNotStartAsyncInitDuringContainerConstruction() + { + AsyncInitServiceProbe.Reset(); + + using var container = new AsyncInjectionContainer(); + + await Assert.That(AsyncInitServiceProbe.ConstructedCount).IsEqualTo(0); + await Assert.That(AsyncInitServiceProbe.InitializeStartedCount).IsEqualTo(0); + + var dependency = container.GetRequiredService(); + + await Assert.That(dependency).IsNotNull(); + await Assert.That(AsyncInitServiceProbe.ConstructedCount).IsEqualTo(0); + await Assert.That(AsyncInitServiceProbe.InitializeStartedCount).IsEqualTo(0); + } + [Test] public async Task EagerResolveContainer_Singleton_IsNotNullAfterConstruction() { diff --git a/src/Ioc/test/SourceGen.Ioc.TestCase/AsyncInjection.cs b/src/Ioc/test/SourceGen.Ioc.TestCase/AsyncInjection.cs index 572e169..c8e89fa 100644 --- a/src/Ioc/test/SourceGen.Ioc.TestCase/AsyncInjection.cs +++ b/src/Ioc/test/SourceGen.Ioc.TestCase/AsyncInjection.cs @@ -7,14 +7,36 @@ public interface IAsyncInitService string? InitializedBy { get; } } +public static class AsyncInitServiceProbe +{ + private static int constructedCount; + private static int initializeStartedCount; + + public static int ConstructedCount => global::System.Threading.Volatile.Read(ref constructedCount); + public static int InitializeStartedCount => global::System.Threading.Volatile.Read(ref initializeStartedCount); + + public static void Reset() + { + global::System.Threading.Interlocked.Exchange(ref constructedCount, 0); + global::System.Threading.Interlocked.Exchange(ref initializeStartedCount, 0); + } + + internal static void OnConstructed() => global::System.Threading.Interlocked.Increment(ref constructedCount); + + internal static void OnInitializeStarted() => global::System.Threading.Interlocked.Increment(ref initializeStartedCount); +} + internal sealed class AsyncInitService : IAsyncInitService { + public AsyncInitService() => AsyncInitServiceProbe.OnConstructed(); + public bool IsInitialized { get; private set; } public string? InitializedBy { get; private set; } [IocInject] public async Task InitializeAsync(IInjectionDependency dep) { + AsyncInitServiceProbe.OnInitializeStarted(); await Task.CompletedTask; InitializedBy = dep.Name; IsInitialized = true;