Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,114 @@ func main() {
// [{human [{What is 1 + 1?}]} {ai [{1 + 1 equals 2.}]}]
}
```

## Conditional Edges

Conditional edges allow you to route to different nodes based on the current state. This is useful for building agents that need to make decisions.

```go
func main() {
g := graph.NewMessageGraph()

// Define nodes
g.AddNode("classifier", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return state, nil
})
g.AddNode("handler_a", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Handled by A")), nil
})
g.AddNode("handler_b", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Handled by B")), nil
})

// Define routing function
routeByContent := func(ctx context.Context, state []llms.MessageContent) string {
for _, msg := range state {
for _, part := range msg.Parts {
if textPart, ok := part.(llms.TextContent); ok {
if strings.Contains(textPart.Text, "option_a") {
return "route_a"
}
}
}
}
return "route_b"
}

// Add conditional edge with routing function and mapping
g.AddConditionalEdges("classifier", routeByContent, map[string]string{
"route_a": "handler_a",
"route_b": "handler_b",
})

g.AddEdge("handler_a", graph.END)
g.AddEdge("handler_b", graph.END)
g.SetEntryPoint("classifier")

runnable, _ := g.Compile()
res, _ := runnable.Invoke(context.Background(), []llms.MessageContent{
llms.TextParts(schema.ChatMessageTypeHuman, "I want option_a"),
})

fmt.Println(res)
// Output: routes to handler_a
}
```

## Max Iterations

To prevent infinite loops in cyclic graphs, you can set a maximum number of iterations. The default is 25.

```go
g := graph.NewMessageGraph()
g.SetMaxIterations(100) // Allow up to 100 iterations
```

When the limit is exceeded, `Invoke` returns `graph.ErrMaxIterationsExceeded`.

## Fan-Out / Fan-In (Parallel Execution)

When multiple edges originate from a single node, the target nodes execute in parallel. Results are merged using a reducer function.

```go
g := graph.NewMessageGraph()

g.AddNode("router", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return state, nil
})
g.AddNode("branch_a", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Result A")), nil
})
g.AddNode("branch_b", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return append(state, llms.TextParts(schema.ChatMessageTypeAI, "Result B")), nil
})
g.AddNode("aggregator", func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error) {
return state, nil
})

// Fan-out: router -> branch_a AND branch_b (parallel)
g.AddEdge("router", "branch_a")
g.AddEdge("router", "branch_b")

// Fan-in: both branches -> aggregator
g.AddEdge("branch_a", "aggregator")
g.AddEdge("branch_b", "aggregator")

g.AddEdge("aggregator", graph.END)
g.SetEntryPoint("router")
```

### Custom Reducer

By default, results from parallel branches are concatenated. You can provide a custom reducer:

```go
g.SetReducer(func(results [][]llms.MessageContent) []llms.MessageContent {
// Custom merge logic
var merged []llms.MessageContent
for _, r := range results {
merged = append(merged, r...)
}
return merged
})
```
199 changes: 188 additions & 11 deletions graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync"

"github.com/tmc/langchaingo/llms"
)
Expand All @@ -20,6 +21,12 @@ var (

// ErrNoOutgoingEdge is returned when no outgoing edge is found for a node.
ErrNoOutgoingEdge = errors.New("no outgoing edge found for node")

// ErrInvalidConditionResult is returned when a condition function returns an unmapped result.
ErrInvalidConditionResult = errors.New("condition returned unmapped result")

// ErrMaxIterationsExceeded is returned when the graph exceeds the maximum number of iterations.
ErrMaxIterationsExceeded = errors.New("max iterations exceeded")
)

// Node represents a node in the message graph.
Expand All @@ -41,6 +48,36 @@ type Edge struct {
To string
}

// Reducer is a function that merges multiple state results into one.
type Reducer func(results [][]llms.MessageContent) []llms.MessageContent

// defaultReducer is the Reducer used when none is set: it concatenates the
// message slices from every branch, in the order the results were collected.
func defaultReducer(results [][]llms.MessageContent) []llms.MessageContent {
	var merged []llms.MessageContent
	for _, branch := range results {
		merged = append(merged, branch...)
	}
	return merged
}

// ConditionalEdge represents a conditional edge that routes based on state.
// Only one ConditionalEdge may exist per source node (they are stored in a
// map keyed by From; a later AddConditionalEdges call replaces the earlier one).
type ConditionalEdge struct {
	// From is the name of the node from which the edge originates.
	From string

	// Condition is a function that evaluates the current state and returns a routing key.
	Condition func(ctx context.Context, state []llms.MessageContent) string

	// Mapping maps routing keys returned by Condition to target node names.
	// A key returned by Condition that is absent from Mapping is an error
	// (ErrInvalidConditionResult) at execution time.
	Mapping map[string]string
}

// nodeResult holds the result of a parallel node execution, sent over a
// channel by each fan-out worker goroutine. Exactly one of state/err is
// meaningful: state is ignored when err is non-nil.
type nodeResult struct {
	state []llms.MessageContent
	err   error
}

// MessageGraph represents a message graph.
type MessageGraph struct {
// nodes is a map of node names to their corresponding Node objects.
Expand All @@ -49,17 +86,39 @@ type MessageGraph struct {
// edges is a slice of Edge objects representing the connections between nodes.
edges []Edge

// conditionalEdges is a map of source node names to their ConditionalEdge.
conditionalEdges map[string]ConditionalEdge

// entryPoint is the name of the entry point node in the graph.
entryPoint string

// maxIterations is the maximum number of iterations allowed before returning an error.
maxIterations int

// reducer is the function used to merge parallel execution results.
reducer Reducer
}

// NewMessageGraph creates a new instance of MessageGraph with an empty node
// set, the default iteration limit of 25, and the default concatenating
// reducer for fan-out results.
func NewMessageGraph() *MessageGraph {
	g := &MessageGraph{
		nodes:            make(map[string]Node),
		conditionalEdges: make(map[string]ConditionalEdge),
	}
	g.maxIterations = 25
	g.reducer = defaultReducer
	return g
}

// SetMaxIterations sets the maximum number of iterations allowed before
// Invoke returns ErrMaxIterationsExceeded. The default is 25.
//
// Non-positive values are ignored: a limit of zero or less would make every
// invocation fail before the graph could even reach END.
func (g *MessageGraph) SetMaxIterations(n int) {
	if n <= 0 {
		return
	}
	g.maxIterations = n
}

// SetReducer sets a custom reducer function for merging parallel execution
// results.
//
// A nil fn is ignored so the graph keeps its current reducer; storing nil
// would make any later fan-out execution panic when the reducer is invoked.
func (g *MessageGraph) SetReducer(fn Reducer) {
	if fn == nil {
		return
	}
	g.reducer = fn
}

// AddNode adds a new node to the message graph with the given name and function.
func (g *MessageGraph) AddNode(name string, fn func(ctx context.Context, state []llms.MessageContent) ([]llms.MessageContent, error)) {
g.nodes[name] = Node{
Expand All @@ -76,6 +135,29 @@ func (g *MessageGraph) AddEdge(from, to string) {
})
}

// AddConditionalEdges adds a conditional edge from a node.
// The condition function receives the current state and returns a routing key,
// and mapping translates routing keys into target node names. A node has at
// most one conditional edge; calling this again for the same from node
// replaces the previous one.
//
// Example:
//
//	g.AddConditionalEdges("classify", routeByCategory, map[string]string{
//		"refund":  "refund_handler",
//		"billing": "billing_handler",
//		"general": "general_handler",
//	})
func (g *MessageGraph) AddConditionalEdges(
	from string,
	condition func(ctx context.Context, state []llms.MessageContent) string,
	mapping map[string]string,
) {
	edge := ConditionalEdge{
		From:      from,
		Condition: condition,
		Mapping:   mapping,
	}
	g.conditionalEdges[from] = edge
}

// SetEntryPoint sets the entry point node name for the message graph.
func (g *MessageGraph) SetEntryPoint(name string) {
g.entryPoint = name
Expand All @@ -99,15 +181,101 @@ func (g *MessageGraph) Compile() (*Runnable, error) {
}, nil
}

// Invoke executes the compiled message graph with the given input messages.
// It returns the resulting messages and an error if any occurs during the execution.
// collectNextNodes returns every target node reachable from currentNode:
// first the target selected by the node's conditional edge (if one exists),
// then the targets of all regular edges, in declaration order.
//
// It returns ErrInvalidConditionResult when the condition function yields a
// routing key that is absent from the edge's mapping.
//
// NOTE(review): ctx is conventionally the first parameter — confirm whether
// the signature can be reordered without breaking other callers.
func (r *Runnable) collectNextNodes(currentNode string, ctx context.Context, state []llms.MessageContent) ([]string, error) {
	var targets []string

	// A conditional edge, when present, contributes its routed target first.
	if ce, ok := r.graph.conditionalEdges[currentNode]; ok {
		key := ce.Condition(ctx, state)
		target, mapped := ce.Mapping[key]
		if !mapped {
			return nil, fmt.Errorf("%w: %s (got %q)", ErrInvalidConditionResult, currentNode, key)
		}
		targets = append(targets, target)
	}

	// Then every regular edge leaving this node.
	for _, e := range r.graph.edges {
		if e.From != currentNode {
			continue
		}
		targets = append(targets, e.To)
	}

	return targets, nil
}

// executeFanOut executes the given nodes in parallel against the current
// state and merges their results with the graph's reducer.
//
// Each goroutine is handed its own copy of the state slice. Without the
// copy, two branches that append to the shared slice (as node functions
// typically do) could write into the same backing array concurrently — a
// data race that silently clobbers a sibling branch's result.
func (r *Runnable) executeFanOut(ctx context.Context, state []llms.MessageContent, nodes []string) ([]llms.MessageContent, error) {
	// Buffered to len(nodes) so workers never block on send, even if we
	// return early on the first error below.
	results := make(chan nodeResult, len(nodes))
	var wg sync.WaitGroup

	for _, nodeName := range nodes {
		wg.Add(1)
		go func(name string) {
			defer wg.Done()

			node, ok := r.graph.nodes[name]
			if !ok {
				results <- nodeResult{err: fmt.Errorf("%w: %s", ErrNodeNotFound, name)}
				return
			}

			// Isolate this branch's view of the state so concurrent
			// appends cannot share a backing array.
			branchState := make([]llms.MessageContent, len(state))
			copy(branchState, state)

			result, err := node.Function(ctx, branchState)
			results <- nodeResult{state: result, err: err}
		}(nodeName)
	}

	// Close the channel once every worker has reported, so the collection
	// loop below terminates.
	go func() {
		wg.Wait()
		close(results)
	}()

	var allResults [][]llms.MessageContent
	for res := range results {
		if res.err != nil {
			return nil, res.err
		}
		allResults = append(allResults, res.state)
	}

	return r.graph.reducer(allResults), nil
}

// findFanInNode finds the common target node that all parallel branches point
// to (the fan-in point). When no single node is a target of every branch,
// END is returned.
//
// Candidates are scanned in edge-declaration order so the answer is
// deterministic. The previous map-based counting picked an arbitrary winner
// when several common targets existed (map iteration order is random), and
// duplicate edges to the same target broke its exact-count comparison.
func (r *Runnable) findFanInNode(parallelNodes []string) string {
	if len(parallelNodes) == 0 {
		return END
	}

	// hasEdge reports whether a regular edge from -> to exists.
	hasEdge := func(from, to string) bool {
		for _, e := range r.graph.edges {
			if e.From == from && e.To == to {
				return true
			}
		}
		return false
	}

	// Consider each target of the first branch, in declaration order, and
	// accept the first one that every other branch also points to.
	for _, edge := range r.graph.edges {
		if edge.From != parallelNodes[0] {
			continue
		}
		common := true
		for _, other := range parallelNodes[1:] {
			if !hasEdge(other, edge.To) {
				common = false
				break
			}
		}
		if common {
			return edge.To
		}
	}

	return END
}

// Invoke executes the compiled message graph with the given input messages.
// It returns the resulting messages and an error if any occurs during the execution.
func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) ([]llms.MessageContent, error) {
state := messages
currentNode := r.graph.entryPoint
iterations := 0

for {
iterations++
if iterations > r.graph.maxIterations {
return nil, fmt.Errorf("%w: %d", ErrMaxIterationsExceeded, r.graph.maxIterations)
}

if currentNode == END {
break
}
Expand All @@ -123,18 +291,27 @@ func (r *Runnable) Invoke(ctx context.Context, messages []llms.MessageContent) (
return nil, fmt.Errorf("error in node %s: %w", currentNode, err)
}

foundNext := false
for _, edge := range r.graph.edges {
if edge.From == currentNode {
currentNode = edge.To
foundNext = true
break
}
// Collect all next nodes
nextNodes, err := r.collectNextNodes(currentNode, ctx, state)
if err != nil {
return nil, err
}

if !foundNext {
if len(nextNodes) == 0 {
return nil, fmt.Errorf("%w: %s", ErrNoOutgoingEdge, currentNode)
}

if len(nextNodes) == 1 {
// Single path - continue normally
currentNode = nextNodes[0]
} else {
// Multiple paths - fan-out execution
state, err = r.executeFanOut(ctx, state, nextNodes)
if err != nil {
return nil, err
}
currentNode = r.findFanInNode(nextNodes)
}
}

return state, nil
Expand Down
Loading