feat(react): demonstrate short-term memory
Change-Id: Ib2b796a7f66ee33b9115c9331489c8d86e42b8c5drew/english
parent
d35371e0af
commit
77bd95ba71
@ -0,0 +1,97 @@
|
|||||||
|
# Short-Term Memory for ReAct Agent
|
||||||
|
|
||||||
|
This example demonstrates a minimal short-term memory for a `flow/react` agent:
|
||||||
|
|
||||||
|
1. Run the agent with a new input message list, get the assistant output.
|
||||||
|
2. Serialize and persist the original input messages plus the assistant output.
|
||||||
|
3. On the next run, restore the stored messages, append the new input, and continue the conversation.
|
||||||
|
4. Do not persist the system message; inject it at runtime via `MessageModifier`.
|
||||||
|
5. Storage options include an in-memory map and Redis (with optional in-memory `miniredis`).
|
||||||
|
|
||||||
|
## Where to Look
|
||||||
|
|
||||||
|
- `main.go` — minimal demo: two turns that share memory and a system prompt injected at runtime.
|
||||||
|
- `memory/store.go` — `MemoryStore` interface and Gob encode/decode helpers.
|
||||||
|
- `memory/inmem.go` — in-memory store.
|
||||||
|
- `memory/redis.go` — Redis-backed store and `NewMiniRedisClient()` for an embedded Redis server.
|
||||||
|
|
||||||
|
## System Prompt Handling
|
||||||
|
|
||||||
|
- Do not persist system messages.
|
||||||
|
- Use the agent hook `react.AgentConfig.MessageModifier` to prepend the system prompt at execution time:
|
||||||
|
|
||||||
|
```go
|
||||||
|
agent, _ := react.NewAgent(ctx, &react.AgentConfig{
|
||||||
|
Model: model,
|
||||||
|
MessageModifier: func(_ context.Context, input []*schema.Message) []*schema.Message {
|
||||||
|
return append([]*schema.Message{schema.SystemMessage(sys)}, input...)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Serialization
|
||||||
|
|
||||||
|
- Messages are serialized using `encoding/gob`.
|
||||||
|
- Eino registers the necessary types, so no manual `gob.Register` is required here.
|
||||||
|
|
||||||
|
## Quick Start (OpenAI)
|
||||||
|
|
||||||
|
Environment variables:
|
||||||
|
|
||||||
|
- `OPENAI_API_KEY`
|
||||||
|
- `OPENAI_MODEL` (e.g., `gpt-4o-mini`)
|
||||||
|
- `OPENAI_BASE_URL` (optional for proxy endpoints)
|
||||||
|
- `OPENAI_BY_AZURE`
|
||||||
|
|
||||||
|
Build and run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd flow/agent/react/memory_example
|
||||||
|
go build -o memory_example main.go
|
||||||
|
./memory_example
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected output:
|
||||||
|
|
||||||
|
- First run: prints assistant response; memory stores the turn.
|
||||||
|
- Second run: restores prior messages, appends the new input, and maintains context.
|
||||||
|
|
||||||
|
## Switch Storage Implementations
|
||||||
|
|
||||||
|
Use the in-memory store (default):
|
||||||
|
|
||||||
|
```go
|
||||||
|
store := memory.NewInMemoryStore()
|
||||||
|
```
|
||||||
|
|
||||||
|
Use Redis with in-memory `miniredis`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
cli, closer, _ := memory.NewMiniRedisClient()
|
||||||
|
defer closer()
|
||||||
|
store := memory.NewRedisStore(cli)
|
||||||
|
```
|
||||||
|
|
||||||
|
Use Redis with a real server:
|
||||||
|
|
||||||
|
```go
|
||||||
|
cli := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
|
||||||
|
store := memory.NewRedisStore(cli)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Minimal Flow
|
||||||
|
|
||||||
|
```go
|
||||||
|
sessionID := "session:demo"
|
||||||
|
prev, _ := store.Read(ctx, sessionID)
|
||||||
|
effective := append(prev, schema.UserMessage(userInput))
|
||||||
|
resp, _ := agent.Generate(ctx, effective)
|
||||||
|
_ = store.Write(ctx, sessionID, append(effective, resp))
|
||||||
|
|
||||||
|
hits, _ := store.Query(ctx, sessionID, "CloudWeGo", 3)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- The example uses `Generate`. You can use `Stream` similarly and persist on `io.EOF`.
|
||||||
|
- Keep the memory window small to cap serialization size; this can be enforced by your store implementation.
|
||||||
@ -0,0 +1,98 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 CloudWeGo Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
|
"github.com/cloudwego/eino/flow/agent/react"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-examples/flow/agent/react/memory_example/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||||
|
modelName := os.Getenv("OPENAI_MODEL")
|
||||||
|
baseURL := os.Getenv("OPENAI_BASE_URL")
|
||||||
|
isAzure := os.Getenv("OPENAI_BY_AZURE") == "true"
|
||||||
|
|
||||||
|
model, err := openai.NewChatModel(ctx, &openai.ChatModelConfig{APIKey: apiKey, Model: modelName, BaseURL: baseURL, ByAzure: isAzure})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// System prompt is injected at runtime and not persisted.
|
||||||
|
sys := "You are a concise assistant. Maintain context across turns."
|
||||||
|
|
||||||
|
agent, err := react.NewAgent(ctx, &react.AgentConfig{
|
||||||
|
Model: model,
|
||||||
|
MessageModifier: func(_ context.Context, input []*schema.Message) []*schema.Message {
|
||||||
|
return append([]*schema.Message{schema.SystemMessage(sys)}, input...)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Choose your store: InMemoryStore (default) or RedisStore (see README).
|
||||||
|
store := memory.NewInMemoryStore()
|
||||||
|
sessionID := "session:demo"
|
||||||
|
|
||||||
|
verifyGobRoundTrip()
|
||||||
|
|
||||||
|
run := func(turn string) {
|
||||||
|
// 1) restore prior messages, 2) append new input, 3) call agent, 4) persist with output
|
||||||
|
prev, _ := store.Read(ctx, sessionID)
|
||||||
|
eff := append(prev, schema.UserMessage(turn))
|
||||||
|
msg, err := agent.Generate(ctx, eff)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("history_before=%d after=%d\n", len(prev), len(eff)+1)
|
||||||
|
fmt.Println(msg.Content)
|
||||||
|
_ = store.Write(ctx, sessionID, append(eff, msg))
|
||||||
|
|
||||||
|
hits, _ := store.Query(ctx, sessionID, "AI", 3)
|
||||||
|
fmt.Printf("query_hits=%d\n", len(hits))
|
||||||
|
}
|
||||||
|
|
||||||
|
run("Hello, summarize AI briefly.")
|
||||||
|
run("Add two more details.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyGobRoundTrip() {
|
||||||
|
msgs := []*schema.Message{
|
||||||
|
schema.UserMessage("a"),
|
||||||
|
schema.AssistantMessage("b", nil),
|
||||||
|
}
|
||||||
|
// Round-trip serialize/deserialize to validate gob setup.
|
||||||
|
b, err := memory.EncodeMessages(msgs)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
out, err := memory.DecodeMessages(b)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("gob_round_trip=%d\n", len(out))
|
||||||
|
}
|
||||||
@ -0,0 +1,81 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 CloudWeGo Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InMemoryStore keeps serialized messages in a process-local map.
|
||||||
|
// Suitable for demos/tests; not shared across processes.
|
||||||
|
type InMemoryStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
data map[string][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInMemoryStore() *InMemoryStore {
|
||||||
|
return &InMemoryStore{data: make(map[string][]byte)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write encodes and stores messages for the given key.
|
||||||
|
func (s *InMemoryStore) Write(ctx context.Context, sessionID string, msgs []*schema.Message) error {
|
||||||
|
b, err := EncodeMessages(msgs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.data[sessionID] = b
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read returns decoded messages for the given session; returns nil if absent.
|
||||||
|
func (s *InMemoryStore) Read(ctx context.Context, sessionID string) ([]*schema.Message, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
b := s.data[sessionID]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return DecodeMessages(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query performs a simple substring search on message contents for the session.
|
||||||
|
func (s *InMemoryStore) Query(ctx context.Context, sessionID string, text string, limit int) ([]*schema.Message, error) {
|
||||||
|
msgs, err := s.Read(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(msgs) == 0 || text == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
q := strings.ToLower(text)
|
||||||
|
out := make([]*schema.Message, 0, limit)
|
||||||
|
for _, m := range msgs {
|
||||||
|
if m == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToLower(m.Content), q) {
|
||||||
|
out = append(out, m)
|
||||||
|
if limit > 0 && len(out) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
@ -0,0 +1,91 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 CloudWeGo Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
miniredis "github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RedisStore persists serialized messages in Redis under the provided session key.
|
||||||
|
type RedisStore struct {
|
||||||
|
cli *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRedisStore(cli *redis.Client) *RedisStore {
|
||||||
|
return &RedisStore{cli: cli}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write encodes and stores messages using Redis SET.
|
||||||
|
func (s *RedisStore) Write(ctx context.Context, sessionID string, msgs []*schema.Message) error {
|
||||||
|
b, err := EncodeMessages(msgs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.cli.Set(ctx, sessionID, b, 0).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read returns decoded messages from Redis GET; returns nil if not found.
|
||||||
|
func (s *RedisStore) Read(ctx context.Context, sessionID string) ([]*schema.Message, error) {
|
||||||
|
res, err := s.cli.Get(ctx, sessionID).Bytes()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return DecodeMessages(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *RedisStore) Query(ctx context.Context, sessionID string, text string, limit int) ([]*schema.Message, error) {
|
||||||
|
msgs, err := s.Read(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(msgs) == 0 || text == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
out := make([]*schema.Message, 0, limit)
|
||||||
|
q := strings.ToLower(text)
|
||||||
|
for _, m := range msgs {
|
||||||
|
if m == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToLower(m.Content), q) {
|
||||||
|
out = append(out, m)
|
||||||
|
if limit > 0 && len(out) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMiniRedisClient starts an embedded Redis server for local demos/tests.
|
||||||
|
func NewMiniRedisClient() (*redis.Client, func(), error) {
|
||||||
|
srv, err := miniredis.Run()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
cli := redis.NewClient(&redis.Options{Addr: srv.Addr()})
|
||||||
|
closer := func() { srv.Close() }
|
||||||
|
return cli, closer, nil
|
||||||
|
}
|
||||||
@ -0,0 +1,58 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2025 CloudWeGo Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/gob"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemoryStore persists and restores short-term conversation history.
|
||||||
|
// Implementations are responsible for storing a slice of messages under a session key.
|
||||||
|
type MemoryStore interface {
|
||||||
|
Write(ctx context.Context, sessionID string, msgs []*schema.Message) error
|
||||||
|
Read(ctx context.Context, sessionID string) ([]*schema.Message, error)
|
||||||
|
Query(ctx context.Context, sessionID string, text string, limit int) ([]*schema.Message, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gob registrations for eino message types are provided by the framework; no manual registration needed here.
|
||||||
|
|
||||||
|
// EncodeMessages serializes messages using Gob.
|
||||||
|
func EncodeMessages(msgs []*schema.Message) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := gob.NewEncoder(&buf)
|
||||||
|
if err := enc.Encode(msgs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeMessages deserializes messages previously encoded by EncodeMessages.
|
||||||
|
func DecodeMessages(b []byte) ([]*schema.Message, error) {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
dec := gob.NewDecoder(bytes.NewReader(b))
|
||||||
|
var msgs []*schema.Message
|
||||||
|
if err := dec.Decode(&msgs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return msgs, nil
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue