You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
209 lines
4.3 KiB
Go
209 lines
4.3 KiB
Go
/*
|
|
* Copyright 2025 CloudWeGo Authors
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package mem
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
)
|
|
|
|
func GetDefaultMemory() *SimpleMemory {
|
|
return NewSimpleMemory(SimpleMemoryConfig{
|
|
Dir: "data/memory",
|
|
MaxWindowSize: 6,
|
|
})
|
|
}
|
|
|
|
// SimpleMemoryConfig holds the settings consumed by NewSimpleMemory.
type SimpleMemoryConfig struct {
	// Dir is the directory that holds one "<id>.jsonl" file per
	// conversation. NewSimpleMemory substitutes "/tmp/eino/memory" when it
	// is empty.
	Dir string

	// MaxWindowSize is the maximum number of most recent messages returned
	// by Conversation.GetMessages.
	MaxWindowSize int
}
|
|
|
|
func NewSimpleMemory(cfg SimpleMemoryConfig) *SimpleMemory {
|
|
if cfg.Dir == "" {
|
|
cfg.Dir = "/tmp/eino/memory"
|
|
}
|
|
if err := os.MkdirAll(cfg.Dir, 0755); err != nil {
|
|
return nil
|
|
}
|
|
|
|
return &SimpleMemory{
|
|
dir: cfg.Dir,
|
|
maxWindowSize: cfg.MaxWindowSize,
|
|
conversations: make(map[string]*Conversation),
|
|
}
|
|
}
|
|
|
|
// simple memory can store messages of each conversation
|
|
type SimpleMemory struct {
|
|
mu sync.Mutex
|
|
dir string
|
|
maxWindowSize int
|
|
conversations map[string]*Conversation
|
|
}
|
|
|
|
func (m *SimpleMemory) GetConversation(id string, createIfNotExist bool) *Conversation {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
_, ok := m.conversations[id]
|
|
|
|
filePath := filepath.Join(m.dir, id+".jsonl")
|
|
if !ok {
|
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
|
if createIfNotExist {
|
|
if err := os.WriteFile(filePath, []byte(""), 0644); err != nil {
|
|
return nil
|
|
}
|
|
m.conversations[id] = &Conversation{
|
|
ID: id,
|
|
Messages: make([]*schema.Message, 0),
|
|
filePath: filePath,
|
|
maxWindowSize: m.maxWindowSize,
|
|
}
|
|
}
|
|
}
|
|
|
|
con := &Conversation{
|
|
ID: id,
|
|
Messages: make([]*schema.Message, 0),
|
|
filePath: filePath,
|
|
maxWindowSize: m.maxWindowSize,
|
|
}
|
|
con.load()
|
|
m.conversations[id] = con
|
|
}
|
|
|
|
return m.conversations[id]
|
|
}
|
|
|
|
func (m *SimpleMemory) ListConversations() []string {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
files, err := os.ReadDir(m.dir)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
ids := make([]string, 0, len(files))
|
|
for _, file := range files {
|
|
if file.IsDir() {
|
|
continue
|
|
}
|
|
ids = append(ids, strings.TrimSuffix(file.Name(), ".jsonl"))
|
|
}
|
|
|
|
return ids
|
|
}
|
|
|
|
func (m *SimpleMemory) DeleteConversation(id string) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
filePath := filepath.Join(m.dir, id+".jsonl")
|
|
if err := os.Remove(filePath); err != nil {
|
|
return fmt.Errorf("failed to delete file: %w", err)
|
|
}
|
|
|
|
delete(m.conversations, id)
|
|
return nil
|
|
}
|
|
|
|
type Conversation struct {
|
|
mu sync.Mutex
|
|
|
|
ID string `json:"id"`
|
|
Messages []*schema.Message `json:"messages"`
|
|
|
|
filePath string
|
|
|
|
maxWindowSize int
|
|
}
|
|
|
|
func (c *Conversation) Append(msg *schema.Message) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
c.Messages = append(c.Messages, msg)
|
|
|
|
c.save(msg)
|
|
}
|
|
|
|
func (c *Conversation) GetFullMessages() []*schema.Message {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
return c.Messages
|
|
}
|
|
|
|
// get messages with max window size
|
|
func (c *Conversation) GetMessages() []*schema.Message {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if len(c.Messages) > c.maxWindowSize {
|
|
return c.Messages[len(c.Messages)-c.maxWindowSize:]
|
|
}
|
|
|
|
return c.Messages
|
|
}
|
|
|
|
func (c *Conversation) load() error {
|
|
reader, err := os.Open(c.filePath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open file: %w", err)
|
|
}
|
|
defer reader.Close()
|
|
|
|
scanner := bufio.NewScanner(reader)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
var msg schema.Message
|
|
if err := json.Unmarshal([]byte(line), &msg); err != nil {
|
|
return fmt.Errorf("failed to unmarshal message: %w", err)
|
|
}
|
|
c.Messages = append(c.Messages, &msg)
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return fmt.Errorf("scanner error: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Conversation) save(msg *schema.Message) {
|
|
str, _ := json.Marshal(msg)
|
|
|
|
// Append to file
|
|
f, err := os.OpenFile(c.filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer f.Close()
|
|
f.Write(str)
|
|
f.WriteString("\n")
|
|
}
|