You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

209 lines
4.3 KiB
Go

/*
* Copyright 2025 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mem
import (
"bufio"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/cloudwego/eino/schema"
)
func GetDefaultMemory() *SimpleMemory {
return NewSimpleMemory(SimpleMemoryConfig{
Dir: "data/memory",
MaxWindowSize: 6,
})
}
type SimpleMemoryConfig struct {
Dir string
MaxWindowSize int
}
func NewSimpleMemory(cfg SimpleMemoryConfig) *SimpleMemory {
if cfg.Dir == "" {
cfg.Dir = "/tmp/eino/memory"
}
if err := os.MkdirAll(cfg.Dir, 0755); err != nil {
return nil
}
return &SimpleMemory{
dir: cfg.Dir,
maxWindowSize: cfg.MaxWindowSize,
conversations: make(map[string]*Conversation),
}
}
// simple memory can store messages of each conversation
type SimpleMemory struct {
mu sync.Mutex
dir string
maxWindowSize int
conversations map[string]*Conversation
}
func (m *SimpleMemory) GetConversation(id string, createIfNotExist bool) *Conversation {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.conversations[id]
filePath := filepath.Join(m.dir, id+".jsonl")
if !ok {
if _, err := os.Stat(filePath); os.IsNotExist(err) {
if createIfNotExist {
if err := os.WriteFile(filePath, []byte(""), 0644); err != nil {
return nil
}
m.conversations[id] = &Conversation{
ID: id,
Messages: make([]*schema.Message, 0),
filePath: filePath,
maxWindowSize: m.maxWindowSize,
}
}
}
con := &Conversation{
ID: id,
Messages: make([]*schema.Message, 0),
filePath: filePath,
maxWindowSize: m.maxWindowSize,
}
con.load()
m.conversations[id] = con
}
return m.conversations[id]
}
func (m *SimpleMemory) ListConversations() []string {
m.mu.Lock()
defer m.mu.Unlock()
files, err := os.ReadDir(m.dir)
if err != nil {
return nil
}
ids := make([]string, 0, len(files))
for _, file := range files {
if file.IsDir() {
continue
}
ids = append(ids, strings.TrimSuffix(file.Name(), ".jsonl"))
}
return ids
}
func (m *SimpleMemory) DeleteConversation(id string) error {
m.mu.Lock()
defer m.mu.Unlock()
filePath := filepath.Join(m.dir, id+".jsonl")
if err := os.Remove(filePath); err != nil {
return fmt.Errorf("failed to delete file: %w", err)
}
delete(m.conversations, id)
return nil
}
type Conversation struct {
mu sync.Mutex
ID string `json:"id"`
Messages []*schema.Message `json:"messages"`
filePath string
maxWindowSize int
}
func (c *Conversation) Append(msg *schema.Message) {
c.mu.Lock()
defer c.mu.Unlock()
c.Messages = append(c.Messages, msg)
c.save(msg)
}
func (c *Conversation) GetFullMessages() []*schema.Message {
c.mu.Lock()
defer c.mu.Unlock()
return c.Messages
}
// get messages with max window size
func (c *Conversation) GetMessages() []*schema.Message {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.Messages) > c.maxWindowSize {
return c.Messages[len(c.Messages)-c.maxWindowSize:]
}
return c.Messages
}
func (c *Conversation) load() error {
reader, err := os.Open(c.filePath)
if err != nil {
return fmt.Errorf("failed to open file: %w", err)
}
defer reader.Close()
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
line := scanner.Text()
var msg schema.Message
if err := json.Unmarshal([]byte(line), &msg); err != nil {
return fmt.Errorf("failed to unmarshal message: %w", err)
}
c.Messages = append(c.Messages, &msg)
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("scanner error: %w", err)
}
return nil
}
func (c *Conversation) save(msg *schema.Message) {
str, _ := json.Marshal(msg)
// Append to file
f, err := os.OpenFile(c.filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return
}
defer f.Close()
f.Write(str)
f.WriteString("\n")
}