feat(react): demonstrate short-term memory
Change-Id: Ib2b796a7f66ee33b9115c9331489c8d86e42b8c5drew/english
parent
d35371e0af
commit
77bd95ba71
@ -0,0 +1,97 @@
|
||||
# Short-Term Memory for ReAct Agent
|
||||
|
||||
This example demonstrates a minimal short-term memory for a `flow/react` agent:
|
||||
|
||||
1. Run the agent with a new input message list, get the assistant output.
|
||||
2. Serialize and persist the original input messages plus the assistant output.
|
||||
3. On the next run, restore the stored messages, append the new input, and continue the conversation.
|
||||
4. Do not persist the system message; inject it at runtime via `MessageModifier`.
|
||||
5. Storage options include an in-memory map and Redis (with optional in-memory `miniredis`).
|
||||
|
||||
## Where to Look
|
||||
|
||||
- `main.go` — minimal demo: two turns that share memory and a system prompt injected at runtime.
|
||||
- `memory/store.go` — `MemoryStore` interface and Gob encode/decode helpers.
|
||||
- `memory/inmem.go` — in-memory store.
|
||||
- `memory/redis.go` — Redis-backed store and `NewMiniRedisClient()` for an embedded Redis server.
|
||||
|
||||
## System Prompt Handling
|
||||
|
||||
- Do not persist system messages.
|
||||
- Use the agent hook `react.AgentConfig.MessageModifier` to prepend the system prompt at execution time:
|
||||
|
||||
```go
|
||||
agent, _ := react.NewAgent(ctx, &react.AgentConfig{
|
||||
Model: model,
|
||||
MessageModifier: func(_ context.Context, input []*schema.Message) []*schema.Message {
|
||||
return append([]*schema.Message{schema.SystemMessage(sys)}, input...)
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Serialization
|
||||
|
||||
- Messages are serialized using `encoding/gob`.
|
||||
- Eino registers the necessary types, so no manual `gob.Register` is required here.
|
||||
|
||||
## Quick Start (OpenAI)
|
||||
|
||||
Environment variables:
|
||||
|
||||
- `OPENAI_API_KEY`
|
||||
- `OPENAI_MODEL` (e.g., `gpt-4o-mini`)
|
||||
- `OPENAI_BASE_URL` (optional for proxy endpoints)
|
||||
- `OPENAI_BY_AZURE`
|
||||
|
||||
Build and run:
|
||||
|
||||
```bash
|
||||
cd flow/agent/react/memory_example
|
||||
go build -o memory_example main.go
|
||||
./memory_example
|
||||
```
|
||||
|
||||
Expected output:
|
||||
|
||||
- First run: prints assistant response; memory stores the turn.
|
||||
- Second run: restores prior messages, appends the new input, and maintains context.
|
||||
|
||||
## Switch Storage Implementations
|
||||
|
||||
Use the in-memory store (default):
|
||||
|
||||
```go
|
||||
store := memory.NewInMemoryStore()
|
||||
```
|
||||
|
||||
Use Redis with in-memory `miniredis`:
|
||||
|
||||
```go
|
||||
cli, closer, _ := memory.NewMiniRedisClient()
|
||||
defer closer()
|
||||
store := memory.NewRedisStore(cli)
|
||||
```
|
||||
|
||||
Use Redis with a real server:
|
||||
|
||||
```go
|
||||
cli := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
|
||||
store := memory.NewRedisStore(cli)
|
||||
```
|
||||
|
||||
## Minimal Flow
|
||||
|
||||
```go
|
||||
sessionID := "session:demo"
|
||||
prev, _ := store.Read(ctx, sessionID)
|
||||
effective := append(prev, schema.UserMessage(userInput))
|
||||
resp, _ := agent.Generate(ctx, effective)
|
||||
_ = store.Write(ctx, sessionID, append(effective, resp))
|
||||
|
||||
hits, _ := store.Query(ctx, sessionID, "CloudWeGo", 3)
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The example uses `Generate`. You can use `Stream` similarly and persist on `io.EOF`.
|
||||
- Keep the memory window small to cap serialization size; this can be enforced by your store implementation.
|
||||
@ -0,0 +1,98 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/flow/agent/react"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
|
||||
"github.com/cloudwego/eino-examples/flow/agent/react/memory_example/memory"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
modelName := os.Getenv("OPENAI_MODEL")
|
||||
baseURL := os.Getenv("OPENAI_BASE_URL")
|
||||
isAzure := os.Getenv("OPENAI_BY_AZURE") == "true"
|
||||
|
||||
model, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{APIKey: apiKey, Model: modelName, BaseURL: baseURL, ByAzure: isAzure})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// System prompt is injected at runtime and not persisted.
|
||||
sys := "You are a concise assistant. Maintain context across turns."
|
||||
|
||||
agent, err := react.NewAgent(ctx, &react.AgentConfig{
|
||||
Model: model,
|
||||
MessageModifier: func(_ context.Context, input []*schema.Message) []*schema.Message {
|
||||
return append([]*schema.Message{schema.SystemMessage(sys)}, input...)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Choose your store: InMemoryStore (default) or RedisStore (see README).
|
||||
store := memory.NewInMemoryStore()
|
||||
sessionID := "session:demo"
|
||||
|
||||
verifyGobRoundTrip()
|
||||
|
||||
run := func(turn string) {
|
||||
// 1) restore prior messages, 2) append new input, 3) call agent, 4) persist with output
|
||||
prev, _ := store.Read(ctx, sessionID)
|
||||
eff := append(prev, schema.UserMessage(turn))
|
||||
msg, err := agent.Generate(ctx, eff)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("history_before=%d after=%d\n", len(prev), len(eff)+1)
|
||||
fmt.Println(msg.Content)
|
||||
_ = store.Write(ctx, sessionID, append(eff, msg))
|
||||
|
||||
hits, _ := store.Query(ctx, sessionID, "AI", 3)
|
||||
fmt.Printf("query_hits=%d\n", len(hits))
|
||||
}
|
||||
|
||||
run("Hello, summarize AI briefly.")
|
||||
run("Add two more details.")
|
||||
}
|
||||
|
||||
func verifyGobRoundTrip() {
|
||||
msgs := []*schema.Message{
|
||||
schema.UserMessage("a"),
|
||||
schema.AssistantMessage("b", nil),
|
||||
}
|
||||
// Round-trip serialize/deserialize to validate gob setup.
|
||||
b, err := memory.EncodeMessages(msgs)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
out, err := memory.DecodeMessages(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("gob_round_trip=%d\n", len(out))
|
||||
}
|
||||
@ -0,0 +1,81 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// InMemoryStore keeps serialized messages in a process-local map.
|
||||
// Suitable for demos/tests; not shared across processes.
|
||||
type InMemoryStore struct {
|
||||
mu sync.RWMutex
|
||||
data map[string][]byte
|
||||
}
|
||||
|
||||
func NewInMemoryStore() *InMemoryStore {
|
||||
return &InMemoryStore{data: make(map[string][]byte)}
|
||||
}
|
||||
|
||||
// Write encodes and stores messages for the given key.
|
||||
func (s *InMemoryStore) Write(ctx context.Context, sessionID string, msgs []*schema.Message) error {
|
||||
b, err := EncodeMessages(msgs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.data[sessionID] = b
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read returns decoded messages for the given session; returns nil if absent.
|
||||
func (s *InMemoryStore) Read(ctx context.Context, sessionID string) ([]*schema.Message, error) {
|
||||
s.mu.RLock()
|
||||
b := s.data[sessionID]
|
||||
s.mu.RUnlock()
|
||||
return DecodeMessages(b)
|
||||
}
|
||||
|
||||
// Query performs a simple substring search on message contents for the session.
|
||||
func (s *InMemoryStore) Query(ctx context.Context, sessionID string, text string, limit int) ([]*schema.Message, error) {
|
||||
msgs, err := s.Read(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(msgs) == 0 || text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
q := strings.ToLower(text)
|
||||
out := make([]*schema.Message, 0, limit)
|
||||
for _, m := range msgs {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(strings.ToLower(m.Content), q) {
|
||||
out = append(out, m)
|
||||
if limit > 0 && len(out) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@ -0,0 +1,91 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
miniredis "github.com/alicebob/miniredis/v2"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisStore persists serialized messages in Redis under the provided session key.
|
||||
type RedisStore struct {
|
||||
cli *redis.Client
|
||||
}
|
||||
|
||||
func NewRedisStore(cli *redis.Client) *RedisStore {
|
||||
return &RedisStore{cli: cli}
|
||||
}
|
||||
|
||||
// Write encodes and stores messages using Redis SET.
|
||||
func (s *RedisStore) Write(ctx context.Context, sessionID string, msgs []*schema.Message) error {
|
||||
b, err := EncodeMessages(msgs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.cli.Set(ctx, sessionID, b, 0).Err()
|
||||
}
|
||||
|
||||
// Read returns decoded messages from Redis GET; returns nil if not found.
|
||||
func (s *RedisStore) Read(ctx context.Context, sessionID string) ([]*schema.Message, error) {
|
||||
res, err := s.cli.Get(ctx, sessionID).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DecodeMessages(res)
|
||||
}
|
||||
|
||||
func (s *RedisStore) Query(ctx context.Context, sessionID string, text string, limit int) ([]*schema.Message, error) {
|
||||
msgs, err := s.Read(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(msgs) == 0 || text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
out := make([]*schema.Message, 0, limit)
|
||||
q := strings.ToLower(text)
|
||||
for _, m := range msgs {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(strings.ToLower(m.Content), q) {
|
||||
out = append(out, m)
|
||||
if limit > 0 && len(out) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// NewMiniRedisClient starts an embedded Redis server for local demos/tests.
|
||||
func NewMiniRedisClient() (*redis.Client, func(), error) {
|
||||
srv, err := miniredis.Run()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cli := redis.NewClient(&redis.Options{Addr: srv.Addr()})
|
||||
closer := func() { srv.Close() }
|
||||
return cli, closer, nil
|
||||
}
|
||||
@ -0,0 +1,58 @@
|
||||
/*
|
||||
* Copyright 2025 CloudWeGo Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package memory
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/gob"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// MemoryStore persists and restores short-term conversation history.
|
||||
// Implementations are responsible for storing a slice of messages under a session key.
|
||||
type MemoryStore interface {
|
||||
Write(ctx context.Context, sessionID string, msgs []*schema.Message) error
|
||||
Read(ctx context.Context, sessionID string) ([]*schema.Message, error)
|
||||
Query(ctx context.Context, sessionID string, text string, limit int) ([]*schema.Message, error)
|
||||
}
|
||||
|
||||
// Gob registrations for eino message types are provided by the framework; no manual registration needed here.
|
||||
|
||||
// EncodeMessages serializes messages using Gob.
|
||||
func EncodeMessages(msgs []*schema.Message) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
if err := enc.Encode(msgs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// DecodeMessages deserializes messages previously encoded by EncodeMessages.
|
||||
func DecodeMessages(b []byte) ([]*schema.Message, error) {
|
||||
if len(b) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
dec := gob.NewDecoder(bytes.NewReader(b))
|
||||
var msgs []*schema.Message
|
||||
if err := dec.Decode(&msgs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msgs, nil
|
||||
}
|
||||
Loading…
Reference in New Issue