diff --git a/README.md b/README.md index 4a011e5..54427ec 100644 --- a/README.md +++ b/README.md @@ -66,3 +66,114 @@ func main() { // [{human [{What is 1 + 1?}]} {ai [{1 + 1 equals 2.}]}] } ``` + +## Conditional Edges + +Conditional edges allow you to route to different nodes based on the current state. This is useful for building agents that need to make decisions. + +```go +func main() { + g := graph.NewMessageGraph() + + // Define nodes + g.AddNode("classifier", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return state, nil + }) + g.AddNode("handler_a", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Handled by A")), nil + }) + g.AddNode("handler_b", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Handled by B")), nil + }) + + // Define routing function + routeByContent := func(ctx context.Context, state []llms.MessageContent) string { + for _, msg := range state { + for _, part := range msg.Parts { + if textPart, ok := part.(llms.TextContent); ok { + if strings.Contains(textPart.Text, "option_a") { + return "route_a" + } + } + } + } + return "route_b" + } + + // Add conditional edge with routing function and mapping + g.AddConditionalEdges("classifier", routeByContent, map[string]string{ + "route_a": "handler_a", + "route_b": "handler_b", + }) + + g.AddEdge("handler_a", graph.END) + g.AddEdge("handler_b", graph.END) + g.SetEntryPoint("classifier") + + runnable, _ := g.Compile() + res, _ := runnable.Invoke(context.Background(), []llms.MessageContent{ + llms.TextParts(schema.ChatMessageTypeHuman, "I want option_a"), + }) + + fmt.Println(res) + // Output: routes to handler_a +} +``` + +## Max Iterations + +To prevent infinite loops in cyclic graphs, you can set a maximum number of iterations. The default is 25. + +```go +g := graph.NewMessageGraph() +g.SetMaxIterations(100) // Allow up to 100 iterations +``` + +When the limit is exceeded, `Invoke` returns `graph.ErrMaxIterationsExceeded`. + +## Fan-Out / Fan-In (Parallel Execution) + +When multiple edges originate from a single node, the target nodes execute in parallel. Results are merged using a reducer function. + +```go +g := graph.NewMessageGraph() + +g.AddNode("router", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return state, nil +}) +g.AddNode("branch_a", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Result A")), nil +}) +g.AddNode("branch_b", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Result B")), nil +}) +g.AddNode("aggregator", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return state, nil +}) + +// Fan-out: router -> branch_a AND branch_b (parallel) +g.AddEdge("router", "branch_a") +g.AddEdge("router", "branch_b") + +// Fan-in: both branches -> aggregator +g.AddEdge("branch_a", "aggregator") +g.AddEdge("branch_b", "aggregator") + +g.AddEdge("aggregator", graph.END) +g.SetEntryPoint("router") +``` + +### Custom Reducer + +By default, results from parallel branches are concatenated. You can provide a custom reducer: + +```go +g.SetReducer(func(results [][]llms.MessageContent) []llms.MessageContent { + // Custom merge logic + var merged []llms.MessageContent + for _, r := range results { + merged = append(merged, r...) + } + return merged +}) +``` diff --git a/graph/graph.go b/graph/graph.go index 421b69e..e6f6928 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync" "github.com/tmc/langchaingo/llms" ) @@ -20,6 +21,12 @@ var ( // ErrNoOutgoingEdge is returned when no outgoing edge is found for a node. ErrNoOutgoingEdge = errors.New("no outgoing edge found for node") + + // ErrInvalidConditionResult is returned when a condition function returns an unmapped result. + ErrInvalidConditionResult = errors.New("condition returned unmapped result") + + // ErrMaxIterationsExceeded is returned when the graph exceeds the maximum number of iterations. + ErrMaxIterationsExceeded = errors.New("max iterations exceeded") ) // Node represents a node in the message graph. @@ -41,6 +48,36 @@ type Edge struct { To string } +// Reducer is a function that merges multiple state results into one. +type Reducer func(results [][]llms.MessageContent) []llms.MessageContent + +// defaultReducer concatenates all message slices. +func defaultReducer(results [][]llms.MessageContent) []llms.MessageContent { + var merged []llms.MessageContent + for _, r := range results { + merged = append(merged, r...) + } + return merged +} + +// ConditionalEdge represents a conditional edge that routes based on state. +type ConditionalEdge struct { + // From is the name of the node from which the edge originates. + From string + + // Condition is a function that evaluates the current state and returns a routing key. + Condition func(ctx context.Context, state []llms.MessageContent) string + + // Mapping maps routing keys returned by Condition to target node names. + Mapping map[string]string +} + +// nodeResult holds the result of a parallel node execution. +type nodeResult struct { + state []llms.MessageContent + err error +} + // MessageGraph represents a message graph. type MessageGraph struct { // nodes is a map of node names to their corresponding Node objects. @@ -49,17 +86,39 @@ type MessageGraph struct { // edges is a slice of Edge objects representing the connections between nodes. edges []Edge + // conditionalEdges is a map of source node names to their ConditionalEdge. + conditionalEdges map[string]ConditionalEdge + // entryPoint is the name of the entry point node in the graph. entryPoint string + + // maxIterations is the maximum number of iterations allowed before returning an error. + maxIterations int + + // reducer is the function used to merge parallel execution results. + reducer Reducer } // NewMessageGraph creates a new instance of MessageGraph. func NewMessageGraph() *MessageGraph { return &MessageGraph{ - nodes: make(map[string]Node), + nodes: make(map[string]Node), + conditionalEdges: make(map[string]ConditionalEdge), + maxIterations: 25, + reducer: defaultReducer, } } +// SetMaxIterations sets the maximum number of iterations allowed before returning an error. +func (g *MessageGraph) SetMaxIterations(n int) { + g.maxIterations = n +} + +// SetReducer sets a custom reducer function for merging parallel execution results. +func (g *MessageGraph) SetReducer(fn Reducer) { + g.reducer = fn +} + // AddNode adds a new node to the message graph with the given name and function. func (g *MessageGraph) AddNode(name string, fn func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error)) { g.nodes[name] = Node{ @@ -76,6 +135,29 @@ func (g *MessageGraph) AddEdge(from, to string) { }) } +// AddConditionalEdges adds a conditional edge from a node. +// The condition function receives the current state and returns a routing key. +// The mapping maps routing keys to target node names. +// +// Example: +// +// g.AddConditionalEdges("classify", routeByCategory, map[string]string{ +// "refund": "refund_handler", +// "billing": "billing_handler", +// "general": "general_handler", +// }) +func (g *MessageGraph) AddConditionalEdges( + from string, + condition func(ctx context.Context, state []llms.MessageContent) string, + mapping map[string]string, +) { + g.conditionalEdges[from] = ConditionalEdge{ + From: from, + Condition: condition, + Mapping: mapping, + } +} + // SetEntryPoint sets the entry point node name for the message graph. func (g *MessageGraph) SetEntryPoint(name string) { g.entryPoint = name @@ -99,15 +181,101 @@ func (g *MessageGraph) Compile() (*Runnable, error) { }, nil } -// Invoke executes the compiled message graph with the given input messages. -// It returns the resulting messages and an error if any occurs during the execution. +// collectNextNodes returns all target nodes from a given node. +func (r *Runnable) collectNextNodes(currentNode string, ctx context.Context, state []llms.MessageContent) ([]string, error) { + var nextNodes []string + + // Check conditional edges first + if condEdge, ok := r.graph.conditionalEdges[currentNode]; ok { + result := condEdge.Condition(ctx, state) + if target, ok := condEdge.Mapping[result]; ok { + nextNodes = append(nextNodes, target) + } else { + return nil, fmt.Errorf("%w: %s (got %q)", ErrInvalidConditionResult, currentNode, result) + } + } + + // Collect all regular edges + for _, edge := range r.graph.edges { + if edge.From == currentNode { + nextNodes = append(nextNodes, edge.To) + } + } + + return nextNodes, nil +} + +// executeFanOut executes multiple nodes in parallel and merges their results. +func (r *Runnable) executeFanOut(ctx context.Context, state []llms.MessageContent, nodes []string) ([]llms.MessageContent, error) { + results := make(chan nodeResult, len(nodes)) + var wg sync.WaitGroup + + for _, nodeName := range nodes { + wg.Add(1) + go func(name string) { + defer wg.Done() + + node, ok := r.graph.nodes[name] + if !ok { + results <- nodeResult{err: fmt.Errorf("%w: %s", ErrNodeNotFound, name)} + return + } + + result, err := node.Function(ctx, state) + results <- nodeResult{state: result, err: err} + }(nodeName) + } + + go func() { + wg.Wait() + close(results) + }() + + var allResults [][]llms.MessageContent + for res := range results { + if res.err != nil { + return nil, res.err + } + allResults = append(allResults, res.state) + } + + return r.graph.reducer(allResults), nil +} + +// findFanInNode finds the common target node that all parallel branches point to. +func (r *Runnable) findFanInNode(parallelNodes []string) string { + targetCounts := make(map[string]int) + + for _, nodeName := range parallelNodes { + for _, edge := range r.graph.edges { + if edge.From == nodeName { + targetCounts[edge.To]++ + } + } + } + + for target, count := range targetCounts { + if count == len(parallelNodes) { + return target + } + } + + return END +} + // Invoke executes the compiled message graph with the given input messages. // It returns the resulting messages and an error if any occurs during the execution. func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) ([]llms.MessageContent, error) { state := messages currentNode := r.graph.entryPoint + iterations := 0 for { + iterations++ + if iterations > r.graph.maxIterations { + return nil, fmt.Errorf("%w: %d", ErrMaxIterationsExceeded, r.graph.maxIterations) + } + if currentNode == END { break } @@ -123,18 +291,27 @@ func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) ( return nil, fmt.Errorf("error in node %s: %w", currentNode, err) } - foundNext := false - for _, edge := range r.graph.edges { - if edge.From == currentNode { - currentNode = edge.To - foundNext = true - break - } + // Collect all next nodes + nextNodes, err := r.collectNextNodes(currentNode, ctx, state) + if err != nil { + return nil, err } - if !foundNext { + if len(nextNodes) == 0 { return nil, fmt.Errorf("%w: %s", ErrNoOutgoingEdge, currentNode) } + + if len(nextNodes) == 1 { + // Single path - continue normally + currentNode = nextNodes[0] + } else { + // Multiple paths - fan-out execution + state, err = r.executeFanOut(ctx, state, nextNodes) + if err != nil { + return nil, err + } + currentNode = r.findFanInNode(nextNodes) + } } return state, nil diff --git a/graph/graph_test.go b/graph/graph_test.go index 1e04d8a..53060a9 100644 --- a/graph/graph_test.go +++ b/graph/graph_test.go @@ -137,6 +137,108 @@ func TestMessageGraph(t *testing.T) { }, expectedError: errors.New("error in node node1: node error"), }, + { + name: "Conditional edge - route to handler_a", + buildGraph: func() *graph.MessageGraph { + g := graph.NewMessageGraph() + g.AddNode("classifier", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "classified")), nil + }) + g.AddNode("handler_a", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Handler A")), nil + }) + g.AddNode("handler_b", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Handler B")), nil + }) + g.AddConditionalEdges("classifier", func(_ context.Context, _ []llms.MessageContent) string { + return "option_a" + }, map[string]string{ + "option_a": "handler_a", + "option_b": "handler_b", + }) + g.AddEdge("handler_a", graph.END) + g.AddEdge("handler_b", graph.END) + g.SetEntryPoint("classifier") + return g + }, + inputMessages: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "Input")}, + expectedOutput: []llms.MessageContent{ + llms.TextParts(schema.ChatMessageTypeHuman, "Input"), + llms.TextParts(schema.ChatMessageTypeAI, "classified"), + llms.TextParts(schema.ChatMessageTypeAI, "Handler A"), + }, + expectedError: nil, + }, + { + name: "Conditional edge - route to handler_b", + buildGraph: func() *graph.MessageGraph { + g := graph.NewMessageGraph() + g.AddNode("classifier", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "classified")), nil + }) + g.AddNode("handler_a", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Handler A")), nil + }) + g.AddNode("handler_b", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Handler B")), nil + }) + g.AddConditionalEdges("classifier", func(_ context.Context, _ []llms.MessageContent) string { + return "option_b" + }, map[string]string{ + "option_a": "handler_a", + "option_b": "handler_b", + }) + g.AddEdge("handler_a", graph.END) + g.AddEdge("handler_b", graph.END) + g.SetEntryPoint("classifier") + return g + }, + inputMessages: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "Input")}, + expectedOutput: []llms.MessageContent{ + llms.TextParts(schema.ChatMessageTypeHuman, "Input"), + llms.TextParts(schema.ChatMessageTypeAI, "classified"), + llms.TextParts(schema.ChatMessageTypeAI, "Handler B"), + }, + expectedError: nil, + }, + { + name: "Conditional edge - invalid condition result", + buildGraph: func() *graph.MessageGraph { + g := graph.NewMessageGraph() + g.AddNode("classifier", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return state, nil + }) + g.AddNode("handler_a", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return state, nil + }) + g.AddConditionalEdges("classifier", func(_ context.Context, _ []llms.MessageContent) string { + return "unknown_route" + }, map[string]string{ + "option_a": "handler_a", + }) + g.SetEntryPoint("classifier") + return g + }, + inputMessages: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "Input")}, + expectedOutput: nil, + expectedError: fmt.Errorf("%w: classifier (got %q)", graph.ErrInvalidConditionResult, "unknown_route"), + }, + { + name: "Max iterations exceeded", + buildGraph: func() *graph.MessageGraph { + g := graph.NewMessageGraph() + g.SetMaxIterations(3) + g.AddNode("node1", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return state, nil + }) + g.AddEdge("node1", "node1") // cycle + g.SetEntryPoint("node1") + return g + }, + inputMessages: []llms.MessageContent{llms.TextParts(schema.ChatMessageTypeHuman, "Input")}, + expectedOutput: nil, + expectedError: fmt.Errorf("%w: 3", graph.ErrMaxIterationsExceeded), + }, } for _, tc := range testCases { @@ -177,3 +279,114 @@ func TestMessageGraph(t *testing.T) { }) } } + +func TestFanOutExecution(t *testing.T) { + t.Parallel() + + g := graph.NewMessageGraph() + g.AddNode("router", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return state, nil + }) + g.AddNode("branch_a", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Branch A")), nil + }) + g.AddNode("branch_b", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Branch B")), nil + }) + g.AddNode("aggregator", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Aggregated")), nil + }) + g.AddEdge("router", "branch_a") + g.AddEdge("router", "branch_b") + g.AddEdge("branch_a", "aggregator") + g.AddEdge("branch_b", "aggregator") + g.AddEdge("aggregator", graph.END) + g.SetEntryPoint("router") + + runnable, err := g.Compile() + if err != nil { + t.Fatalf("failed to compile: %v", err) + } + + result, err := runnable.Invoke(context.Background(), []llms.MessageContent{ + llms.TextParts(schema.ChatMessageTypeHuman, "Input"), + }) + if err != nil { + t.Fatalf("failed to invoke: %v", err) + } + + // Check that we have results from both branches (order may vary) + resultStr := fmt.Sprint(result) + if !stringContains(resultStr, "Branch A") { + t.Errorf("expected result to contain 'Branch A', got: %s", resultStr) + } + if !stringContains(resultStr, "Branch B") { + t.Errorf("expected result to contain 'Branch B', got: %s", resultStr) + } + if !stringContains(resultStr, "Aggregated") { + t.Errorf("expected result to contain 'Aggregated', got: %s", resultStr) + } +} + +func TestCustomReducer(t *testing.T) { + t.Parallel() + + g := graph.NewMessageGraph() + + // Custom reducer that only keeps the first result + g.SetReducer(func(results [][]llms.MessageContent) []llms.MessageContent { + if len(results) > 0 { + return results[0] + } + return nil + }) + + g.AddNode("router", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return state, nil + }) + g.AddNode("branch_a", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Branch A")), nil + }) + g.AddNode("branch_b", func(_ context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) { + return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Branch B")), nil + }) + g.AddEdge("router", "branch_a") + g.AddEdge("router", "branch_b") + g.AddEdge("branch_a", graph.END) + g.AddEdge("branch_b", graph.END) + g.SetEntryPoint("router") + + runnable, err := g.Compile() + if err != nil { + t.Fatalf("failed to compile: %v", err) + } + + result, err := runnable.Invoke(context.Background(), []llms.MessageContent{ + llms.TextParts(schema.ChatMessageTypeHuman, "Input"), + }) + if err != nil { + t.Fatalf("failed to invoke: %v", err) + } + + // With custom reducer, we should only have one branch result + resultStr := fmt.Sprint(result) + hasBranchA := stringContains(resultStr, "Branch A") + hasBranchB := stringContains(resultStr, "Branch B") + + // Should have exactly one branch, not both + if hasBranchA && hasBranchB { + t.Errorf("expected only one branch result with custom reducer, got both: %s", resultStr) + } + if !hasBranchA && !hasBranchB { + t.Errorf("expected at least one branch result, got: %s", resultStr) + } +} + +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}