找回密码
 立即注册

4,长期会话记忆

[复制链接]
admin 发表于 4 天前 | 显示全部楼层 |阅读模式
1,定义长期会话管理类

  1. package com.jinhei;
  2. import java.io.*;
  3. import java.nio.file.Files;
  4. import java.nio.file.Paths;
  5. import java.time.LocalDateTime;
  6. import java.time.format.DateTimeFormatter;
  7. import java.util.ArrayList;
  8. import java.util.List;
  9. /**
  10. * 长期记忆管理类
  11. * 用于存储和读取 AI 的对话历史与工具调用记录
  12. * 就像给 AI 配备了一个"笔记本",让它能记住之前发生过的事情
  13. */
  14. public class MemoryManager {
  15.     // 记忆文件路径 - 记忆保存在这里
  16.     private static final String MEMORY_FILE = "agent_memory.json";
  17.     // 内存中的记忆列表
  18.     private List<MemoryEntry> memories;
  19.     /**
  20.      * 记忆条目 - 每一条记忆的结构
  21.      * 包含:角色(用户/AI)、内容、时间戳
  22.      */
  23.     public static class MemoryEntry {
  24.         private String role;           // 角色:user(用户)或 assistant(AI)
  25.         private String content;        // 对话内容
  26.         private String timestamp;      // 时间戳
  27.         private String toolName;       // 使用的工具名称(如果有)
  28.         private String toolResult;     // 工具执行结果(如果有)
  29.         public MemoryEntry(String role, String content) {
  30.             this.role = role;
  31.             this.content = content;
  32.             this.timestamp = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
  33.         }
  34.         // Getter 和 Setter 方法
  35.         public String getRole() { return role; }
  36.         public void setRole(String role) { this.role = role; }
  37.         public String getContent() { return content; }
  38.         public void setContent(String content) { this.content = content; }
  39.         public String getTimestamp() { return timestamp; }
  40.         public void setTimestamp(String timestamp) { this.timestamp = timestamp; }
  41.         public String getToolName() { return toolName; }
  42.         public void setToolName(String toolName) { this.toolName = toolName; }
  43.         public String getToolResult() { return toolResult; }
  44.         public void setToolResult(String toolResult) { this.toolResult = toolResult; }
  45.     }
  46.     /**
  47.      * 构造函数 - 创建记忆管理器时自动加载已有记忆
  48.      */
  49.     public MemoryManager() {
  50.         this.memories = new ArrayList<>();
  51.         loadMemories();  // 从文件加载记忆
  52.     }
  53.     /**
  54.      * 从文件加载记忆
  55.      */
  56.     public void loadMemories() {
  57.         File file = new File(MEMORY_FILE);
  58.         if (!file.exists()) {
  59.             System.out.println("[记忆] 未发现历史记忆文件,从空开始");
  60.             return;
  61.         }
  62.         try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
  63.             StringBuilder content = new StringBuilder();
  64.             String line;
  65.             while ((line = reader.readLine()) != null) {
  66.                 content.append(line);
  67.             }
  68.             // 简单的 JSON 解析(不使用额外库)
  69.             parseMemoriesFromJson(content.toString());
  70.             System.out.println("[记忆] 已加载 " + memories.size() + " 条历史记忆");
  71.         } catch (IOException e) {
  72.             System.err.println("[记忆] 加载失败:" + e.getMessage());
  73.         }
  74.     }
  75.     /**
  76.      * 从 JSON 字符串解析记忆
  77.      * @param json JSON 格式的记忆数据
  78.      */
  79.     private void parseMemoriesFromJson(String json) {
  80.         memories.clear();
  81.         // 简单的 JSON 数组解析
  82.         json = json.trim();
  83.         if (!json.startsWith("[") || !json.endsWith("]")) {
  84.             return;
  85.         }
  86.         // 移除首尾的 [ ]
  87.         json = json.substring(1, json.length() - 1).trim();
  88.         // 分割每个对象
  89.         String[] objects = json.split("\\},\\s*\\{");
  90.         for (String obj : objects) {
  91.             obj = obj.trim();
  92.             if (!obj.startsWith("{")) {
  93.                 obj = "{" + obj;
  94.             }
  95.             if (!obj.endsWith("}")) {
  96.                 obj = obj + "}";
  97.             }
  98.             // 提取各个字段
  99.             String role = extractJsonValue(obj, "role");
  100.             String content = extractJsonValue(obj, "content");
  101.             String timestamp = extractJsonValue(obj, "timestamp");
  102.             String toolName = extractJsonValue(obj, "toolName");
  103.             String toolResult = extractJsonValue(obj, "toolResult");
  104.             if (role != null && content != null) {
  105.                 MemoryEntry entry = new MemoryEntry(role, content);
  106.                 if (timestamp != null) entry.setTimestamp(timestamp);
  107.                 if (toolName != null) entry.setToolName(toolName);
  108.                 if (toolResult != null) entry.setToolResult(toolResult);
  109.                 memories.add(entry);
  110.             }
  111.         }
  112.     }
  113.     /**
  114.      * 添加用户消息到记忆中
  115.      * @param message 用户说的话
  116.      */
  117.     public void addUserMessage(String message) {
  118.         memories.add(new MemoryEntry("user", message));
  119.         System.out.println("[记忆] 已记录用户消息:" + message);
  120.     }
  121.     /**
  122.      * 添加 AI 回复到记忆中
  123.      * @param message AI 的回答
  124.      * @param toolName 使用的工具名称(如果没有则为 null)
  125.      * @param toolResult 工具执行结果(如果没有则为 null)
  126.      */
  127.     public void addAssistantMessage(String message, String toolName, String toolResult) {
  128.         MemoryEntry entry = new MemoryEntry("assistant", message);
  129.         if (toolName != null) {
  130.             entry.setToolName(toolName);
  131.         }
  132.         if (toolResult != null) {
  133.             entry.setToolResult(toolResult);
  134.         }
  135.         memories.add(entry);
  136.         System.out.println("[记忆] 已记录 AI 回复:" + message);
  137.     }
  138.     /**
  139.      * 获取最近的 N 条记忆
  140.      * @param limit 限制数量
  141.      * @return 最近的记忆列表
  142.      */
  143.     public List<MemoryEntry> getRecentMemories(int limit) {
  144.         if (memories.size() <= limit) {
  145.             return memories;
  146.         }
  147.         return memories.subList(memories.size() - limit, memories.size());
  148.     }
  149.     /**
  150.      * 将记忆保存到文件
  151.      * 使用 JSON 格式存储,方便读取和管理
  152.      */
  153.     public void saveMemories() {
  154.         try (FileWriter writer = new FileWriter(MEMORY_FILE)) {
  155.             StringBuilder jsonBuilder = new StringBuilder("[\n");
  156.             for (int i = 0; i < memories.size(); i++) {
  157.                 MemoryEntry entry = memories.get(i);
  158.                 jsonBuilder.append("  {\n");
  159.                 jsonBuilder.append("    "role": "").append(escapeJson(entry.getRole())).append("",\n");
  160.                 jsonBuilder.append("    "content": "").append(escapeJson(entry.getContent())).append("",\n");
  161.                 jsonBuilder.append("    "timestamp": "").append(escapeJson(entry.getTimestamp())).append("",\n");
  162.                 if (entry.getToolName() != null) {
  163.                     jsonBuilder.append("    "toolName": "").append(escapeJson(entry.getToolName())).append("",\n");
  164.                 }
  165.                 if (entry.getToolResult() != null) {
  166.                     jsonBuilder.append("    "toolResult": "").append(escapeJson(entry.getToolResult())).append("",\n");
  167.                 }
  168.                 // 移除最后一个逗号
  169.                 String lastLine = jsonBuilder.substring(jsonBuilder.lastIndexOf(","));
  170.                 jsonBuilder.delete(jsonBuilder.length() - 2, jsonBuilder.length());
  171.                 jsonBuilder.append("\n  }");
  172.                 if (i < memories.size() - 1) {
  173.                     jsonBuilder.append(",");
  174.                 }
  175.                 jsonBuilder.append("\n");
  176.             }
  177.             jsonBuilder.append("]");
  178.             writer.write(jsonBuilder.toString());
  179.             System.out.println("[记忆] 已保存 " + memories.size() + " 条记忆到文件");
  180.         } catch (IOException e) {
  181.             System.err.println("[记忆] 保存失败:" + e.getMessage());
  182.         }
  183.     }
  184.     /**
  185.      * 从 JSON 对象中提取指定字段的值
  186.      * @param json JSON 对象字符串
  187.      * @param key 字段名
  188.      * @return 字段值
  189.      */
  190.     private String extractJsonValue(String json, String key) {
  191.         String searchKey = """ + key + """;
  192.         int keyIndex = json.indexOf(searchKey);
  193.         if (keyIndex == -1) {
  194.             return null;
  195.         }
  196.         int colonIndex = json.indexOf(":", keyIndex);
  197.         if (colonIndex == -1) {
  198.             return null;
  199.         }
  200.         // 跳过冒号后的空格
  201.         int startIndex = colonIndex + 1;
  202.         while (startIndex < json.length() && Character.isWhitespace(json.charAt(startIndex))) {
  203.             startIndex++;
  204.         }
  205.         if (startIndex >= json.length()) {
  206.             return null;
  207.         }
  208.         // 检查是否是字符串值
  209.         if (json.charAt(startIndex) == '"') {
  210.             int endIndex = json.indexOf(""", startIndex + 1);
  211.             if (endIndex == -1) {
  212.                 return null;
  213.             }
  214.             return json.substring(startIndex + 1, endIndex);
  215.         }
  216.         return null;
  217.     }
  218.     /**
  219.      * 转义 JSON 特殊字符
  220.      * @param text 原始文本
  221.      * @return 转义后的文本
  222.      */
  223.     private String escapeJson(String text) {
  224.         if (text == null) return "";
  225.         return text.replace("\", "\\\")
  226.                 .replace(""", "\\"")
  227.                 .replace("\n", "\\n")
  228.                 .replace("\r", "\\r")
  229.                 .replace("\t", "\\t");
  230.     }
  231.     /**
  232.      * 获取记忆统计信息
  233.      * @return 统计信息字符串
  234.      */
  235.     public String getMemoryStats() {
  236.         int userMessages = 0;
  237.         int assistantMessages = 0;
  238.         int toolCalls = 0;
  239.         for (MemoryEntry entry : memories) {
  240.             if ("user".equals(entry.getRole())) {
  241.                 userMessages++;
  242.             } else if ("assistant".equals(entry.getRole())) {
  243.                 assistantMessages++;
  244.                 if (entry.getToolName() != null) {
  245.                     toolCalls++;
  246.                 }
  247.             }
  248.         }
  249.         return String.format("记忆统计:共 %d 条,用户消息 %d 条,AI 回复 %d 条,工具调用 %d 次",
  250.                 memories.size(), userMessages, assistantMessages, toolCalls);
  251.     }
  252. }
复制代码
2,实现会话记忆
  1. package com.jinhei;
  2. import com.alibaba.fastjson2.JSONObject;
  3. import com.openai.client.OpenAIClient;
  4. import com.openai.client.okhttp.OpenAIOkHttpClient;
  5. import com.openai.models.chat.completions.ChatCompletion;
  6. import com.openai.models.chat.completions.ChatCompletionCreateParams;
  7. import java.lang.reflect.InvocationTargetException;
  8. import java.lang.reflect.Method;
  9. import java.util.HashMap;
  10. import java.util.List;
  11. /**
  12. * 调用模型对话
  13. */
  14. public class AiChat {
  15.     private static final String TEMPLATE = """
  16.             你是一位能力强大的 AI 助手,擅长通过逻辑推理与调用工具来解决问题。
  17.             
  18.             你可以使用的工具如下:
  19.             {tools}
  20.             
  21.             **重要:你必须严格按照以下 JSON 格式返回工具调用请求:**
  22.             ```json
  23.             {
  24.               "toolName": "工具名称",
  25.               "params": "{参数 JSON 字符串}"
  26.             }
  27.             ```
  28.             
  29.             注意:
  30.             1. 只能返回上述 JSON 格式,不要添加任何解释文字
  31.             2. params 字段必须是字符串格式的 JSON
  32.             3. 不要使用代码块标记
  33.             
  34.             对话历史:
  35.             {memory}
  36.             
  37.             当前用户的问题是:
  38.             {input}
  39.             """;
  40.     public static void main(String[] args) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
  41.         // 第0步:准备工作 - 注册可用的工具
  42.         // 创建一个工具注册表(就像一个工具箱),记录每个工具的名称和对应的方法
  43.         // key: 工具名称(AI调用时使用)
  44.         // value: 实际执行的Java方法(反射调用)
  45.         // 通俗理解:建立一张"工具对照表"
  46.         // 当AI说"我要用putFile工具"时,程序就知道要调用AgentTools类的putFile方法
  47.         HashMap<String, Method> tools = new HashMap<>();
  48.         tools.put("putFile", AgentTools.class.getMethod("putFile", String.class));
  49.         tools.put("getFile", AgentTools.class.getMethod("getFile", String.class));
  50.         tools.put("batchPutFile", AgentTools.class.getMethod("batchPutFile", String.class));
  51.         // 第 0.5 步:创建记忆管理器 - 加载长期记忆
  52.         MemoryManager memoryManager = new MemoryManager();
  53.         System.out.println(memoryManager.getMemoryStats());
  54.         //第1步:创建AI客户端 - 建立与AI服务的连接通道
  55.         OpenAIClient aiClient = OpenAIOkHttpClient.builder()
  56.                 .apiKey(AiConfig.API_KEY)      // 设置身份证(API密钥)
  57.                 .baseUrl(AiConfig.BASE_URL)    // 设置对方地址(服务地址)
  58.                 .build();                       // 建造完成,电话接通
  59.         //第 2 步:准备要问的问题 - 编写聊天内容
  60.         // 2.1 用户的实际问题(原始输入)
  61. //         String promptString = "批量创建 3 个文件到 D:\\ 中,分别是 a.txt 内容为 Hello A, b.txt 内容为 Hello B, c.txt 内容为 Hello C";
  62. //        String promptString = "读取 D:\\a.txt 文件的内容";  // 这个问题 AI 会直接回答,不需要调用工具
  63.         String promptString = "我之前都让你干了哪些事情";  // 这个问题 AI 会直接回答,不需要调用工具
  64.         // 2.1.5 记录用户消息到记忆中
  65.         memoryManager.addUserMessage(promptString);
  66.         // 2.2 替换模板中的工具占位符
  67.         // ToolUtil.getToolDescription(AgentTools.class) 会返回 AgentTools 类中所有工具的说明
  68.         // 就像把“工具箱清单”填写到工作流程说明书中
  69.         // 告诉 AI:你能用的工具有“写文件”、“读文件”等等...
  70.         String prompt = TEMPLATE.replace("{tools}", ToolUtil.getToolDescription(AgentTools.class));
  71.         // 2.2.5 获取最近的记忆并替换到模板中
  72.         // 这样 AI 就能记住之前的对话内容
  73.         String memoryContext = buildMemoryContext(memoryManager.getRecentMemories(10));
  74.         prompt = prompt.replace("{memory}", memoryContext);
  75.         // 2.3 替换模板中的用户问题占位符
  76.         // 把用户的具体问题填进去
  77.         // 就像告诉AI:"这是你要完成的具体任务"
  78.         prompt = prompt.replace("{input}", promptString);
  79.         //第3步:构建请求参数 - 把问题打包成AI能理解的格式
  80.         ChatCompletionCreateParams params = ChatCompletionCreateParams.builder()
  81.                 .addUserMessage(prompt)           // 添加用户消息:你说的话
  82.                 .model(AiConfig.LLM_NAME)         // 指定用哪个AI大脑来处理
  83.                 .build();                          // 打包完成,准备发送
  84.         //第4步:发送请求并获取回复 - 把信寄出去,等待回信
  85.         ChatCompletion chatCompletion = aiClient.chat()      // 打开聊天功能
  86.                 .completions()   // 开启补全模式(AI回复)
  87.                 .create(params); // 发送请求并等待响应
  88.         // 第5步:提取AI的回答 - 拆开回信,取出内容
  89.         // 5.1 获取AI返回的原始消息
  90.         String message = chatCompletion.choices().get(0).message().content().get();
  91.         // 5.2 清理消息格式 - 移除 markdown 代码块标记和多余文字
  92.         // AI 可能把工具调用放在代码块中,我们需要提取出纯 JSON
  93.         // 例如:将 ```tool_code {...} ``` 转换成 {...}
  94.         message = message.replace("```tool_code", "");
  95.         message = message.replace("```json", "");
  96.         message = message.replace("```", "");
  97.         // 尝试提取第一个 { 到最后一个 } 之间的内容(处理 AI 添加解释文字的情况)
  98.         int firstBrace = message.indexOf('{');
  99.         int lastBrace = message.lastIndexOf('}');
  100.         if (firstBrace != -1 && lastBrace != -1 && lastBrace > firstBrace) {
  101.             message = message.substring(firstBrace, lastBrace + 1);
  102.         }
  103.         // 5.3 解析 JSON - 理解 AI 想要调用哪个工具
  104.         // 将清理后的字符串解析成 JSON 对象,提取工具名称和参数
  105.         // 就像读懂 AI 写的"工具使用申请单"
  106.         JSONObject jsonObject = null;
  107.         try {
  108.             jsonObject = JSONObject.parseObject(message);
  109.             // 检查是否是有效的工具调用格式(必须包含 toolName 和 params)
  110.             String toolName = jsonObject.getString("toolName");
  111.             String toolParams = jsonObject.getString("params");
  112.             if (toolName == null || toolParams == null) {
  113.                 // 格式不对,说明不是工具调用
  114.                 jsonObject = null;
  115.             }
  116.         } catch (Exception e) {
  117.             // 解析失败,说明 AI 返回的不是 JSON,而是直接回答
  118.             System.out.println("\n[AI 直接回复]\n" + message);
  119.             // 记录 AI 的回复到记忆中(没有工具调用)
  120.             memoryManager.addAssistantMessage(message, null, null);
  121.             memoryManager.saveMemories();
  122.             System.out.println(memoryManager.getMemoryStats());
  123.             return; // 直接结束,不执行工具调用
  124.         }
  125.         // 能到这里说明是有效的工具调用
  126.         String toolName = jsonObject.getString("toolName");
  127.         String toolParams = jsonObject.getString("params");
  128.         // 第 6 步:执行工具调用 - 根据 AI 的指令实际操作
  129.         // 6.1 从工具注册表中找到对应的方法
  130.         // 就像根据工具名称从工具箱里取出对应的工具
  131.         Method toolMethod = tools.get(toolName);
  132.         // 6.2 创建工具实例并执行方法
  133.         // 通过反射调用工具方法,就像启动机器开始干活
  134.         // new AgentTools():创建工具对象(准备干活)
  135.         // toolMethod.invoke():执行具体方法(开始干活)
  136.         // toolParams:传入参数(告诉工具要怎么做)
  137.         Object result = toolMethod.invoke(new AgentTools(), toolParams);
  138.         // 第 7 步:显示结果 - 把工具执行的结果打印出来
  139.         System.out.println(result);
  140.         // 第 8 步:保存记忆 - 将本次对话记录保存到长期记忆中
  141.         // 重要:记录自然语言总结,而不是工具调用 JSON
  142.         String assistantSummary = buildAssistantSummary(toolName, result.toString());
  143.         memoryManager.addAssistantMessage(assistantSummary, toolName, result.toString());
  144.         memoryManager.saveMemories();
  145.         System.out.println(memoryManager.getMemoryStats());
  146.     }
  147.     /**
  148.      * 构建 AI 回复的自然语言总结
  149.      * @param toolName 工具名称
  150.      * @param toolResult 工具执行结果
  151.      * @return 自然语言总结
  152.      */
  153.     private static String buildAssistantSummary(String toolName, String toolResult) {
  154.         if ("putFile".equals(toolName)) {
  155.             return "使用写文件工具创建了文件";
  156.         } else if ("getFile".equals(toolName)) {
  157.             return "使用读文件工具读取了文件内容";
  158.         } else if ("batchPutFile".equals(toolName)) {
  159.             // 从结果中提取文件信息
  160.             StringBuilder summary = new StringBuilder("批量创建了以下文件:");
  161.             String[] lines = toolResult.split("\\n");
  162.             for (String line : lines) {
  163.                 if (line.contains("成功写入")) {
  164.                     // 提取文件名
  165.                     int start = line.indexOf("D:");
  166.                     if (start != -1) {
  167.                         String filePath = line.substring(start).trim();
  168.                         summary.append(filePath).append("、");
  169.                     }
  170.                 }
  171.             }
  172.             // 移除最后一个顿号
  173.             if (summary.length() > 0 && summary.charAt(summary.length() - 1) == '、') {
  174.                 summary.setLength(summary.length() - 1);
  175.             }
  176.             return summary.toString();
  177.         }
  178.         return "调用了工具:" + toolName;
  179.     }
  180.     /**
  181.      * 构建记忆上下文字符串
  182.      * @param memories 记忆列表
  183.      * @return 格式化的记忆文本
  184.      */
  185.     private static String buildMemoryContext(List<MemoryManager.MemoryEntry> memories) {
  186.         if (memories.isEmpty()) {
  187.             return "(暂无历史对话)";
  188.         }
  189.         StringBuilder context = new StringBuilder();
  190.         for (MemoryManager.MemoryEntry entry : memories) {
  191.             String roleText = "user".equals(entry.getRole()) ? "用户" : "AI";
  192.             context.append(String.format("[%s] %s: %s\n",
  193.                     entry.getTimestamp(), roleText, entry.getContent()));
  194.             // 如果有工具调用信息,也加上
  195.             if (entry.getToolName() != null) {
  196.                 context.append(String.format("  → 使用工具:%s\n", entry.getToolName()));
  197.             }
  198.             if (entry.getToolResult() != null) {
  199.                 context.append(String.format("  → 工具结果:%s\n", entry.getToolResult()));
  200.             }
  201.         }
  202.         return context.toString();
  203.     }
  204. }
复制代码
aiagent.zip (19.04 KB, 下载次数: 0, 售价: 50 金豆)

QQ|网站地图|Archiver|小黑屋|金黑网络 ( 粤ICP备2021124338号 )

网站建设,微信公众号小程序制作,商城系统开发,高端系统定制,app软件开发,智能物联网开发,直播带货系统等

Powered by Www.Jinhei.Cn

Copyright © 2013-2024 深圳市金黑网络技术有限公司 版权所有

快速回复 返回顶部 返回列表