You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
461 lines
14 KiB
Go
461 lines
14 KiB
Go
/*
|
|
* Copyright 2025 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package visualize
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/chromedp/chromedp"
|
|
"github.com/cloudwego/eino/compose"
|
|
)
|
|
|
|
// MermaidGenerator renders a Mermaid flowchart from a compiled Eino graph (Graph/Chain/Workflow).
// It is meant to be registered as a graph compile callback; the diagram is produced from the
// compile-time GraphInfo passed to OnFinish.
//
// Core concepts and mapping:
//   - Nodes: labeled with their key and component type. Lambda nodes use rounded shapes.
//   - Special nodes: START/END are rendered with safe IDs (start_node/end_node) to avoid
//     Mermaid keyword conflicts.
//   - SubGraphs: nested Graph/Chain/Workflow nodes are rendered as Mermaid subgraphs with their
//     component type in the title.
//   - Branches: each branch becomes a "branch" decision diamond placed between its source node
//     and its possible end nodes.
//
// Edge styles:
//
//	graph/chain  control+data   -->
//	             data-only      -.->
//	workflow     control+data   -- control+data -->
//	             control-only   == control-only ==>
//	             data-only      -. data-only .->
//	branch arrows               --> (graph/chain), ==> (workflow, i.e. control-only)
//
// Workflow styling is used when workflowStyle is set, and is also enabled automatically when
// generate detects workflow-specific edge semantics. Nested sub-graphs pick their style from
// their component type.
//
// Usage:
//
//	gen := visualize.NewMermaidGenerator("./docs") // "" means the current working directory
//	_, _ = g.Compile(ctx, compose.WithGraphCompileCallbacks(gen), compose.WithGraphName("MyGraph"))
//	// Writes ./docs/MyGraph.md (the diagram in a ```mermaid fenced block) and, best effort,
//	// ./docs/MyGraph.png (via the mmdc CLI when installed, otherwise headless Chrome).
type MermaidGenerator struct {
	// w, when non-nil and autoWrite is false, receives the diagram text instead of files
	// being written.
	w io.Writer
	// workflowStyle forces workflow edge styling for the top-level graph.
	workflowStyle bool
	// autoWrite makes generate write files even when w is set.
	autoWrite bool
	// outDir is the directory files are written to; "" means the current working directory.
	outDir string
	// baseName overrides the output file base name; "" derives it from the graph name,
	// falling back to "topology".
	baseName string
	// makeImages additionally renders <name>.png next to the Markdown file (best effort).
	makeImages bool
}
|
|
|
|
// NewMermaidGenerator creates a generator that auto-writes Markdown and attempts PNG/SVG generation.
|
|
// If dir is empty, current working directory is used. File name is derived from graph name or defaults to "topology".
|
|
func NewMermaidGenerator(dir string) *MermaidGenerator {
|
|
return &MermaidGenerator{autoWrite: true, outDir: dir, makeImages: true}
|
|
}
|
|
|
|
// OnFinish is the compile callback entrypoint invoked by Eino after graph compilation.
|
|
// It reads the compile-time GraphInfo and writes a complete Mermaid diagram to the writer.
|
|
func (m *MermaidGenerator) OnFinish(_ context.Context, info *compose.GraphInfo) {
|
|
m.generate(info)
|
|
}
|
|
|
|
// generate orchestrates diagram construction by delegating to renderGraph.
|
|
// The top-level direction is TD (top-down) for readability and consistency.
|
|
func (m *MermaidGenerator) generate(info *compose.GraphInfo) {
|
|
isWorkflow := m.workflowStyle
|
|
if !isWorkflow {
|
|
if len(info.Edges) > len(info.DataEdges) {
|
|
isWorkflow = true
|
|
}
|
|
|
|
if !isWorkflow {
|
|
for from, edges := range info.Edges {
|
|
dataEdges, ok := info.DataEdges[from]
|
|
if !ok {
|
|
isWorkflow = true
|
|
break
|
|
}
|
|
|
|
if len(edges) != len(dataEdges) {
|
|
isWorkflow = true
|
|
break
|
|
}
|
|
|
|
for i := range edges {
|
|
edge := edges[i]
|
|
found := false
|
|
for _, dEdge := range dataEdges {
|
|
if dEdge == edge {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
isWorkflow = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
sb := &strings.Builder{}
|
|
sb.WriteString("graph TD\n")
|
|
m.renderGraph(sb, info, "", 1, isWorkflow)
|
|
if m.w != nil && !m.autoWrite {
|
|
_, _ = fmt.Fprint(m.w, sb.String())
|
|
return
|
|
}
|
|
|
|
dir := m.outDir
|
|
if dir == "" {
|
|
if wd, err := os.Getwd(); err == nil {
|
|
dir = wd
|
|
} else {
|
|
dir = "."
|
|
}
|
|
}
|
|
name := m.baseName
|
|
if name == "" {
|
|
if len(info.Name) > 0 {
|
|
name = sanitize(info.Name)
|
|
} else {
|
|
name = "topology"
|
|
}
|
|
}
|
|
mdPath := filepath.Join(dir, name+".md")
|
|
content := sb.String()
|
|
_ = os.WriteFile(mdPath, []byte("```mermaid\n"+content+"\n```"), 0o644)
|
|
if m.makeImages {
|
|
mmdPath := filepath.Join(dir, name+".mmd")
|
|
_ = os.WriteFile(mmdPath, []byte(content), 0o644)
|
|
m.renderImage(mmdPath, filepath.Join(dir, name+".png"))
|
|
_ = os.Remove(mmdPath)
|
|
}
|
|
}
|
|
|
|
// renderGraph builds a Mermaid diagram section for the given GraphInfo.
|
|
// It:
|
|
// 1) Collects and sorts node keys for deterministic output
|
|
// 2) Renders nodes (including nested sub-graphs)
|
|
// 3) Renders control/data edges with workflow-aware styles
|
|
// 4) Renders branches with decision diamonds and proper edge types
|
|
func (m *MermaidGenerator) renderGraph(sb *strings.Builder, info *compose.GraphInfo, prefix string, indentLevel int, style bool) {
|
|
indent := strings.Repeat(" ", indentLevel)
|
|
|
|
// Collect all nodes from info.Nodes, Edges, and Branches
|
|
allNodes := make(map[string]bool)
|
|
for k := range info.Nodes {
|
|
allNodes[k] = true
|
|
}
|
|
for start, ends := range info.Edges {
|
|
allNodes[start] = true
|
|
for _, end := range ends {
|
|
allNodes[end] = true
|
|
}
|
|
}
|
|
for start, branches := range info.Branches {
|
|
allNodes[start] = true
|
|
for _, branch := range branches {
|
|
endNodes := branch.GetEndNode()
|
|
for end := range endNodes {
|
|
allNodes[end] = true
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sort nodes for deterministic output
|
|
nodes := make([]string, 0, len(allNodes))
|
|
for k := range allNodes {
|
|
nodes = append(nodes, k)
|
|
}
|
|
sort.Strings(nodes)
|
|
|
|
// Render Nodes
|
|
for _, nodeKey := range nodes {
|
|
nodeID := m.nodeID(prefix, nodeKey)
|
|
|
|
if nodeInfo, ok := info.Nodes[nodeKey]; ok {
|
|
if nodeInfo.GraphInfo != nil {
|
|
// Subgraph
|
|
subgraphLabel := nodeKey
|
|
switch nodeInfo.Component {
|
|
case compose.ComponentOfChain:
|
|
subgraphLabel = fmt.Sprintf("%s (Chain)", nodeKey)
|
|
case compose.ComponentOfWorkflow:
|
|
subgraphLabel = fmt.Sprintf("%s (Workflow)", nodeKey)
|
|
case compose.ComponentOfGraph:
|
|
subgraphLabel = fmt.Sprintf("%s (Graph)", nodeKey)
|
|
}
|
|
_, _ = fmt.Fprintf(sb, "%ssubgraph %s [\"%s\"]\n", indent, nodeID, subgraphLabel)
|
|
childStyle := style
|
|
switch nodeInfo.Component {
|
|
case compose.ComponentOfWorkflow:
|
|
childStyle = true
|
|
case compose.ComponentOfGraph, compose.ComponentOfChain:
|
|
// for explicit Graph/Chain sub-graphs, do not apply workflow styling
|
|
childStyle = false
|
|
}
|
|
m.renderGraph(sb, nodeInfo.GraphInfo, nodeID+"_", indentLevel+1, childStyle)
|
|
_, _ = fmt.Fprintf(sb, "%send\n", indent)
|
|
} else {
|
|
// Regular Node
|
|
shapeStart, shapeEnd := "[", "]"
|
|
if nodeInfo.Component == compose.ComponentOfLambda {
|
|
shapeStart, shapeEnd = "(", ")"
|
|
}
|
|
|
|
label := fmt.Sprintf("%s<br/>(%s)", nodeKey, nodeInfo.Component)
|
|
_, _ = fmt.Fprintf(sb, "%s%s%s\"%s\"%s\n", indent, nodeID, shapeStart, label, shapeEnd)
|
|
}
|
|
} else if nodeKey == compose.START || nodeKey == compose.END {
|
|
// Special nodes: avoid reserved keyword conflict with 'end'
|
|
var safeID string
|
|
if nodeKey == compose.START {
|
|
safeID = m.nodeID(prefix, "start_node")
|
|
} else {
|
|
safeID = m.nodeID(prefix, "end_node")
|
|
}
|
|
_, _ = fmt.Fprintf(sb, "%s%s([%s])\n", indent, safeID, nodeKey)
|
|
}
|
|
}
|
|
|
|
// Render Control Edges
|
|
// Sort edges for deterministic output
|
|
startNodes := make([]string, 0, len(info.Edges))
|
|
for k := range info.Edges {
|
|
startNodes = append(startNodes, k)
|
|
}
|
|
sort.Strings(startNodes)
|
|
|
|
for _, start := range startNodes {
|
|
ends := info.Edges[start]
|
|
for _, end := range ends {
|
|
startID := m.nodeID(prefix, start)
|
|
endID := m.nodeID(prefix, end)
|
|
if start == compose.START {
|
|
startID = m.nodeID(prefix, "start_node")
|
|
}
|
|
if end == compose.END {
|
|
endID = m.nodeID(prefix, "end_node")
|
|
}
|
|
// Determine edge semantics by checking if a matching data edge exists.
|
|
hasData := false
|
|
if des, ok := info.DataEdges[start]; ok {
|
|
for _, de := range des {
|
|
if de == end {
|
|
hasData = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if style {
|
|
if hasData {
|
|
_, _ = fmt.Fprintf(sb, "%s%s -- control+data --> %s\n", indent, startID, endID)
|
|
} else {
|
|
_, _ = fmt.Fprintf(sb, "%s%s == control-only ==> %s\n", indent, startID, endID)
|
|
}
|
|
} else {
|
|
_, _ = fmt.Fprintf(sb, "%s%s --> %s\n", indent, startID, endID)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Render Data Edges
|
|
// Only render if they differ from control edges; otherwise already represented as control+data.
|
|
dataStartNodes := make([]string, 0, len(info.DataEdges))
|
|
for k := range info.DataEdges {
|
|
dataStartNodes = append(dataStartNodes, k)
|
|
}
|
|
sort.Strings(dataStartNodes)
|
|
|
|
for _, start := range dataStartNodes {
|
|
ends := info.DataEdges[start]
|
|
for _, end := range ends {
|
|
// Check if this edge already exists as a control edge
|
|
alreadyExists := false
|
|
for _, controlEnd := range info.Edges[start] {
|
|
if controlEnd == end {
|
|
alreadyExists = true
|
|
break
|
|
}
|
|
}
|
|
if !alreadyExists {
|
|
startID := m.nodeID(prefix, start)
|
|
endID := m.nodeID(prefix, end)
|
|
if start == compose.START {
|
|
startID = m.nodeID(prefix, "start_node")
|
|
}
|
|
if end == compose.END {
|
|
endID = m.nodeID(prefix, "end_node")
|
|
}
|
|
if style {
|
|
_, _ = fmt.Fprintf(sb, "%s%s -. data-only .-> %s\n", indent, startID, endID)
|
|
} else {
|
|
_, _ = fmt.Fprintf(sb, "%s%s -.-> %s\n", indent, startID, endID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Render Branches
|
|
branchStarts := make([]string, 0, len(info.Branches))
|
|
for k := range info.Branches {
|
|
branchStarts = append(branchStarts, k)
|
|
}
|
|
sort.Strings(branchStarts)
|
|
|
|
for _, start := range branchStarts {
|
|
branches := info.Branches[start]
|
|
for i, branch := range branches {
|
|
// Branch decision node (diamond)
|
|
// We need a unique ID for the decision point if there are multiple branches from the same node?
|
|
// Actually, `info.Branches` maps startNode -> []GraphBranch.
|
|
// Usually a node has one set of branches.
|
|
// Let's represent the branch condition as a diamond.
|
|
|
|
// If there are multiple branches, they might be parallel or sequential conditions.
|
|
// Eino `AddBranch` adds a branch.
|
|
|
|
// For visualization, maybe we just draw arrows from startNode to endNodes with a label?
|
|
// Or introduce a "decision" node?
|
|
|
|
// Decision node visualization: startNode -> decision{branch} -> endNodes
|
|
|
|
decisionID := fmt.Sprintf("%s_branch_%d", m.nodeID(prefix, start), i)
|
|
_, _ = fmt.Fprintf(sb, "%s%s{\"%s\"}\n", indent, decisionID, "branch")
|
|
startID := m.nodeID(prefix, start)
|
|
if start == compose.START {
|
|
startID = m.nodeID(prefix, "start_node")
|
|
}
|
|
if style {
|
|
_, _ = fmt.Fprintf(sb, "%s%s ==> %s\n", indent, startID, decisionID)
|
|
} else {
|
|
_, _ = fmt.Fprintf(sb, "%s%s --> %s\n", indent, startID, decisionID)
|
|
}
|
|
|
|
// Sort end nodes
|
|
endNodesMap := branch.GetEndNode()
|
|
endNodes := make([]string, 0, len(endNodesMap))
|
|
for k := range endNodesMap {
|
|
endNodes = append(endNodes, k)
|
|
}
|
|
sort.Strings(endNodes)
|
|
|
|
for _, end := range endNodes {
|
|
endID := m.nodeID(prefix, end)
|
|
if end == compose.END {
|
|
endID = m.nodeID(prefix, "end_node")
|
|
}
|
|
if style {
|
|
_, _ = fmt.Fprintf(sb, "%s%s ==> %s\n", indent, decisionID, endID)
|
|
} else {
|
|
_, _ = fmt.Fprintf(sb, "%s%s --> %s\n", indent, decisionID, endID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// nodeID sanitizes a node key to be a valid Mermaid identifier, and adds a caller-provided prefix
|
|
// to ensure uniqueness when rendering nested graphs.
|
|
func (m *MermaidGenerator) nodeID(prefix, key string) string {
|
|
// Sanitize key for Mermaid ID
|
|
safeKey := strings.ReplaceAll(key, " ", "_")
|
|
safeKey = strings.ReplaceAll(safeKey, "-", "_")
|
|
return prefix + safeKey
|
|
}
|
|
|
|
func (m *MermaidGenerator) renderImage(input, output string) {
|
|
if _, err := exec.LookPath("mmdc"); err != nil {
|
|
// fallback to chromedp rendering
|
|
data, rErr := os.ReadFile(input)
|
|
if rErr != nil {
|
|
return
|
|
}
|
|
_ = renderWithChromedp(string(data), output)
|
|
return
|
|
}
|
|
cmd := exec.Command("mmdc", "-i", input, "-o", output)
|
|
_ = cmd.Run()
|
|
}
|
|
|
|
// sanitize turns a graph name into a file base name. It trims surrounding white space and
// replaces spaces, path separators ('/', '\\'), and the other characters Windows rejects in
// file names (':', '*', '?', '"', '<', '>', '|') with '_'. A blank name yields "".
func sanitize(s string) string {
	replacer := strings.NewReplacer(
		" ", "_",
		"/", "_",
		"\\", "_",
		":", "_",
		"*", "_",
		"?", "_",
		"\"", "_",
		"<", "_",
		">", "_",
		"|", "_",
	)
	return replacer.Replace(strings.TrimSpace(s))
}
|
|
|
|
func renderWithChromedp(mermaidCode, output string) error {
|
|
html := buildMermaidHTML(mermaidCode)
|
|
ctx, cancel := chromedp.NewContext(context.Background())
|
|
defer cancel()
|
|
|
|
var err error
|
|
ext := strings.ToLower(filepath.Ext(output))
|
|
switch ext {
|
|
case ".png":
|
|
var buf []byte
|
|
err = chromedp.Run(ctx,
|
|
chromedp.Navigate("data:text/html,"+urlEncode(html)),
|
|
chromedp.WaitVisible(`#container svg`, chromedp.ByQuery),
|
|
chromedp.Screenshot(`#container svg`, &buf, chromedp.NodeVisible, chromedp.ByQuery),
|
|
)
|
|
if err == nil {
|
|
err = os.WriteFile(output, buf, 0o644)
|
|
}
|
|
case ".svg":
|
|
var svg string
|
|
err = chromedp.Run(ctx,
|
|
chromedp.Navigate("data:text/html,"+urlEncode(html)),
|
|
chromedp.WaitVisible(`#container svg`, chromedp.ByQuery),
|
|
chromedp.OuterHTML(`#container svg`, &svg, chromedp.ByQuery),
|
|
)
|
|
if err == nil {
|
|
err = os.WriteFile(output, []byte(svg), 0o644)
|
|
}
|
|
default:
|
|
// unsupported extension
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// buildMermaidHTML wraps Mermaid source in a minimal HTML page. The page loads Mermaid 10 from
// unpkg and renders the diagram inside the element with id "container" on page load.
//
// The code is HTML-escaped (&, <, >) before being embedded. Otherwise a label containing '<'
// would be parsed as markup and corrupt the diagram. Mermaid entity-decodes the container's
// content before parsing, so it still sees the original text, including <br/> line breaks.
func buildMermaidHTML(code string) string {
	escaped := strings.NewReplacer("&", "&amp;", "<", "&lt;", ">", "&gt;").Replace(code)
	return `<!doctype html>
<html>
<head>
<meta charset="utf-8">
<script src="https://unpkg.com/mermaid@10/dist/mermaid.min.js"></script>
<style>body{margin:0;padding:20px;background:#fff}</style>
<script>mermaid.initialize({startOnLoad:true});</script>
</head>
<body>
<div id="container" class="mermaid">` + escaped + `</div>
</body>
</html>`
}
|
|
|
|
// urlEncode percent-encodes the few characters that would otherwise corrupt a
// "data:text/html," URL built from s. Those are '%' (the escape introducer), '#' (which starts
// a fragment) and CR/LF line breaks. Every other byte is copied through unchanged.
func urlEncode(s string) string { // minimal percent-encoding for data URL
	var out strings.Builder
	out.Grow(len(s))
	for i := 0; i < len(s); i++ {
		// Byte-wise scan is safe for UTF-8: multi-byte sequences never contain these ASCII bytes.
		switch c := s[i]; c {
		case '%':
			out.WriteString("%25")
		case '#':
			out.WriteString("%23")
		case '\n':
			out.WriteString("%0A")
		case '\r':
			out.WriteString("%0D")
		default:
			out.WriteByte(c)
		}
	}
	return out.String()
}
|