package tech.aiflowy.ai.controller; import cn.dev33.satoken.annotation.SaIgnore; import com.agentsflex.core.message.AiMessage; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import tech.aiflowy.ai.entity.*; import tech.aiflowy.ai.mapper.AiBotConversationMessageMapper; import tech.aiflowy.ai.service.*; import tech.aiflowy.common.ai.ChatManager; import tech.aiflowy.common.ai.MySseEmitter; import tech.aiflowy.common.domain.Result; import tech.aiflowy.common.util.StringUtil; import tech.aiflowy.common.web.controller.BaseCurdController; import tech.aiflowy.common.web.jsonbody.JsonBody; import tech.aiflowy.common.satoken.util.SaTokenUtil; import cn.hutool.core.util.ObjectUtil; import com.agentsflex.core.llm.ChatContext; import com.agentsflex.core.llm.Llm; import com.agentsflex.core.llm.StreamResponseListener; import com.agentsflex.core.llm.functions.Function; import com.agentsflex.core.llm.response.AiMessageResponse; import com.agentsflex.core.llm.response.FunctionCaller; import com.agentsflex.core.message.HumanMessage; import com.agentsflex.core.message.SystemMessage; import com.agentsflex.core.prompt.HistoriesPrompt; import com.agentsflex.core.util.CollectionUtil; import com.alibaba.fastjson.JSON; import com.mybatisflex.core.query.QueryWrapper; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import javax.annotation.Resource; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.math.BigInteger; import java.util.*; /** * 控制层。 * * @author michael * @since 2024-08-23 */ @RestController @RequestMapping("/api/v1/aiBot") public class AiBotController extends BaseCurdController { private final AiLlmService aiLlmService; private final AiBotWorkflowService aiBotWorkflowService; private final AiBotKnowledgeService aiBotKnowledgeService; private final AiBotMessageService aiBotMessageService; @Resource private AiBotConversationMessageService aiBotConversationMessageService; @Resource private AiBotConversationMessageMapper aiBotConversationMessageMapper; public AiBotController(AiBotService service, AiLlmService aiLlmService, AiBotWorkflowService aiBotWorkflowService, AiBotKnowledgeService aiBotKnowledgeService, AiBotMessageService aiBotMessageService) { super(service); this.aiLlmService = aiLlmService; this.aiBotWorkflowService = aiBotWorkflowService; this.aiBotKnowledgeService = aiBotKnowledgeService; this.aiBotMessageService = aiBotMessageService; } @Resource private AiPluginsService aiPluginsService; @Resource private AiBotPluginsService aiBotPluginsService; @PostMapping("updateOptions") public Result updateOptions(@JsonBody("id") BigInteger id, @JsonBody("options") Map options) { AiBot aiBot = service.getById(id); Map existOptions = aiBot.getOptions(); if (existOptions == null) { existOptions = new HashMap<>(); } if (options != null) { existOptions.putAll(options); } aiBot.setOptions(existOptions); service.updateById(aiBot); return Result.success(); } @PostMapping("updateLlmOptions") public Result updateLlmOptions(@JsonBody("id") BigInteger id, @JsonBody("llmOptions") Map llmOptions) { AiBot aiBot = service.getById(id); Map existLlmOptions = aiBot.getLlmOptions(); if (existLlmOptions == null) { existLlmOptions = new HashMap<>(); } if (llmOptions != null) { existLlmOptions.putAll(llmOptions); } aiBot.setLlmOptions(existLlmOptions); service.updateById(aiBot); return Result.success(); } /** * 当前系统用户调用对话 * @param prompt * @param botId * @param sessionId * @param isExternalMsg * @param response * @return */ @PostMapping("chat") public SseEmitter chat(@JsonBody(value = "prompt", required = true) String prompt, @JsonBody(value = "botId", required = true) BigInteger botId, @JsonBody(value = "sessionId", required = true) String sessionId, @JsonBody(value = "isExternalMsg") int isExternalMsg, HttpServletResponse response) { response.setContentType("text/event-stream"); AiBot aiBot = service.getById(botId); if (aiBot == null) { return ChatManager.getInstance().sseEmitterForContent("机器人不存在"); } Map llmOptions = aiBot.getLlmOptions(); AiLlm aiLlm = aiLlmService.getById(aiBot.getLlmId()); if (aiLlm == null) { return ChatManager.getInstance().sseEmitterForContent("LLM不存在"); } Llm llm = aiLlm.toLlm(); AiBotMessageMemory memory = new AiBotMessageMemory(botId, SaTokenUtil.getLoginAccount().getId(), sessionId, isExternalMsg, aiBotMessageService, aiBotConversationMessageMapper, aiBotConversationMessageService); final HistoriesPrompt historiesPrompt = new HistoriesPrompt(); historiesPrompt.setSystemMessage(SystemMessage.of((String) llmOptions.get("systemPrompt"))); historiesPrompt.setMemory(memory); HumanMessage humanMessage = new HumanMessage(prompt); // 添加插件相关的function calling appendPluginFunctions(botId, humanMessage); //添加工作流相关的 Function Calling appendWorkflowFunctions(botId, humanMessage); //添加知识库相关的 Function Calling appendKnowledgeFunctions(botId, humanMessage); historiesPrompt.addMessage(humanMessage); MySseEmitter emitter = new MySseEmitter((long) (1000 * 60 * 2)); final Boolean[] needClose = {true}; if (humanMessage.getFunctions() != null && !humanMessage.getFunctions().isEmpty()) { try { AiMessageResponse aiMessageResponse = llm.chat(historiesPrompt); function_call(aiMessageResponse, emitter, needClose, historiesPrompt, llm, prompt, false, false); } catch (Exception e) { emitter.completeWithError(e); } if (needClose[0]) { System.out.println("function chat complete"); emitter.complete(); } } else { llm.chatStream(historiesPrompt, new StreamResponseListener() { @Override public void onMessage(ChatContext context, AiMessageResponse response) { try { function_call(response, emitter, needClose, historiesPrompt, llm, prompt, false,true); } catch (Exception e) { emitter.completeWithError(e); } } @Override public void onStop(ChatContext context) { if (needClose[0]) { System.out.println("normal chat complete"); emitter.complete(); } } @Override public void onFailure(ChatContext context, Throwable throwable) { emitter.completeWithError(throwable); } }); } return emitter; } /** * 外部用户调用智能体进行对话 * 需要用户传 apiKey 对用户进行身份验证 * @return * @param stream [true: 返回sse false: 返回json */ @SaIgnore @PostMapping("externalChat") public Object externalChat( @JsonBody(value = "messages", required = true) List messages, @JsonBody(value = "botId", required = true) BigInteger botId, @JsonBody(value = "stream", required = false) boolean stream, HttpServletResponse response, HttpServletRequest request ) { // 设置响应类型 if (stream) { response.setContentType("text/event-stream"); } else { response.setContentType("application/json"); } // 获取 API Key 和 Bot 信息 String apiKey = request.getHeader("Authorization"); AiBot aiBot = service.getById(botId); if (aiBot == null) { return createResponse(stream, "机器人不存在"); } Map llmOptions = aiBot.getLlmOptions(); AiLlm aiLlm = aiLlmService.getById(aiBot.getLlmId()); if (aiLlm == null) { return createResponse(stream, "LLM不存在"); } Llm llm = aiLlm.toLlm(); AiBotExternalMessageMemory messageMemory = new AiBotExternalMessageMemory(messages); HistoriesPrompt historiesPrompt = new HistoriesPrompt(); historiesPrompt.setSystemMessage(SystemMessage.of((String) llmOptions.get("systemPrompt"))); historiesPrompt.setMemory(messageMemory); String prompt = messages.get(messages.size() - 1).getContent(); HumanMessage humanMessage = new HumanMessage(); // 添加插件、工作流、知识库相关的 Function Calling appendPluginFunctions(botId, humanMessage); appendWorkflowFunctions(botId, humanMessage); appendKnowledgeFunctions(botId, humanMessage); historiesPrompt.addMessage(humanMessage); // 根据 responseType 返回不同的响应 if (stream) { MySseEmitter emitter = new MySseEmitter((long) (1000 * 60 * 2)); final Boolean[] needClose = {true}; if (humanMessage.getFunctions() != null && !humanMessage.getFunctions().isEmpty()) { try { AiMessageResponse aiMessageResponse = llm.chat(historiesPrompt); function_call(aiMessageResponse, emitter, needClose, historiesPrompt, llm, prompt, true, true); } catch (Exception e) { emitter.completeWithError(e); } if (needClose[0]) { System.out.println("function chat complete"); emitter.complete(); } } else { llm.chatStream(historiesPrompt, new StreamResponseListener() { @Override public void onMessage(ChatContext context, AiMessageResponse response) { try { function_call(response, emitter, needClose, historiesPrompt, llm, prompt, true, true); } catch (Exception e) { emitter.completeWithError(e); } } @Override public void onStop(ChatContext context) { if (needClose[0]) { System.out.println("normal chat complete"); emitter.complete(); } } @Override public void onFailure(ChatContext context, Throwable throwable) { emitter.completeWithError(throwable); } }); } return emitter; } else { AiMessageResponse resultFunctionCall; if (humanMessage.getFunctions() != null && !humanMessage.getFunctions().isEmpty()) { try { AiMessageResponse aiMessageResponse = llm.chat(historiesPrompt); resultFunctionCall = jsonResultFunctionCall(aiMessageResponse, historiesPrompt, llm, prompt); return JSON.toJSONString(resultFunctionCall.getMessage()); } catch (Exception e) { return createErrorResponse(e); } } else { AiMessageResponse messageResponse = llm.chat(historiesPrompt); resultFunctionCall = jsonResultFunctionCall(messageResponse, historiesPrompt, llm, prompt); AiBotExternalMsgJsonResult result = handleMessageResult(resultFunctionCall.getMessage()); return JSON.toJSONString(result); } } } private AiBotExternalMsgJsonResult handleMessageResult(AiMessage aiMessage) { AiBotExternalMsgJsonResult messageResult = new AiBotExternalMsgJsonResult(); messageResult.setCreated(new Date().getTime()); AiBotExternalMsgJsonResult.Usage usage = new AiBotExternalMsgJsonResult.Usage(); usage.setTotalTokens(aiMessage.getTotalTokens()); usage.setCompletionTokens(aiMessage.getCompletionTokens()); usage.setPromptTokens(aiMessage.getPromptTokens()); messageResult.setUsage(usage); AiBotExternalMsgJsonResult.Choice choice = new AiBotExternalMsgJsonResult.Choice(); AiBotExternalMsgJsonResult.Message message = new AiBotExternalMsgJsonResult.Message(); message.setContent(aiMessage.getContent()); message.setRole("assistant"); choice.setMessage(message); messageResult.setChoices(choice); messageResult.setStatus(aiMessage.getStatus().name()); return messageResult; } // 辅助方法:创建响应 private Object createResponse(boolean stream, String content) { if (stream) { MySseEmitter emitter = new MySseEmitter((long) (1000 * 60 * 2)); emitter.send(content); emitter.complete(); return emitter; } else { return ResponseEntity.ok(content); } } // 辅助方法:创建错误响应 private Object createErrorResponse(Exception e) { return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(e.getMessage()); } /** * * @param aiMessageResponse * @param emitter * @param needClose * @param historiesPrompt * @param llm * @param prompt * @param isChatApi 该参数用于分辨外部地址用户通过apiKey的方式传参,不涉及外部调用的默认 false * @param stream */ private String function_call(AiMessageResponse aiMessageResponse, MySseEmitter emitter, Boolean[] needClose, HistoriesPrompt historiesPrompt, Llm llm, String prompt, boolean isChatApi,boolean stream) { ServletRequestAttributes sra = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); RequestContextHolder.setRequestAttributes(sra, true); String content = aiMessageResponse.getMessage().getContent(); Object messageContent = aiMessageResponse.getMessage(); if (StringUtil.hasText(content)) { emitter.send(messageContent.toString()); } List functionCallers = aiMessageResponse.getFunctionCallers(); if (CollectionUtil.hasItems(functionCallers)) { needClose[0] = false; for (FunctionCaller functionCaller : functionCallers) { Object result = functionCaller.call(); if (ObjectUtil.isNotEmpty(result)) { String newPrompt = "请根据以下内容回答用户,内容是:\n" + result + "\n 用户的问题是:" + prompt; historiesPrompt.addMessageTemporary(new HumanMessage(newPrompt)); llm.chatStream(historiesPrompt, new StreamResponseListener() { @Override public void onMessage(ChatContext context, AiMessageResponse response) { needClose[0] = true; String content = response.getMessage().getContent(); Object messageContent = response.getMessage(); if (StringUtil.hasText(content)) { String jsonResult = JSON.toJSONString(messageContent); emitter.send(jsonResult); } } @Override public void onStop(ChatContext context) { if (needClose[0]) { System.out.println("function chat complete"); emitter.complete(); } historiesPrompt.clearTemporaryMessages(); } @Override public void onFailure(ChatContext context, Throwable throwable) { emitter.completeWithError(throwable); } }); } } } return JSON.toJSONString(messageContent); } private AiMessageResponse jsonResultFunctionCall(AiMessageResponse aiMessageResponse , HistoriesPrompt historiesPrompt, Llm llm, String prompt) { List functionCallers = aiMessageResponse.getFunctionCallers(); if (CollectionUtil.hasItems(functionCallers)) { for (FunctionCaller functionCaller : functionCallers) { Object result = functionCaller.call(); if (ObjectUtil.isNotEmpty(result)) { String newPrompt = "请根据以下内容回答用户,内容是:\n" + result + "\n 用户的问题是:" + prompt; historiesPrompt.addMessageTemporary(new HumanMessage(newPrompt)); return llm.chat(historiesPrompt); } } } return aiMessageResponse; } private void appendWorkflowFunctions(BigInteger botId, HumanMessage humanMessage) { QueryWrapper queryWrapper = QueryWrapper.create().eq(AiBotWorkflow::getBotId, botId); List aiBotWorkflows = aiBotWorkflowService.getMapper().selectListWithRelationsByQuery(queryWrapper); if (aiBotWorkflows != null) { for (AiBotWorkflow aiBotWorkflow : aiBotWorkflows) { Function function = aiBotWorkflow.getWorkflow().toFunction(); humanMessage.addFunction(function); } } } private void appendKnowledgeFunctions(BigInteger botId, HumanMessage humanMessage) { QueryWrapper queryWrapper = QueryWrapper.create().eq(AiBotKnowledge::getBotId, botId); List aiBotKnowledges = aiBotKnowledgeService.getMapper().selectListWithRelationsByQuery(queryWrapper); if (aiBotKnowledges != null) { for (AiBotKnowledge aiBotKnowledge : aiBotKnowledges) { Function function = aiBotKnowledge.getKnowledge().toFunction(); humanMessage.addFunction(function); } } } private void appendPluginFunctions(BigInteger botId, HumanMessage humanMessage) { QueryWrapper queryWrapper = QueryWrapper.create().eq(AiBotPlugins::getBotId, botId); List aiBotPlugins = aiBotPluginsService.getMapper().selectListWithRelationsByQuery(queryWrapper); if (cn.hutool.core.collection.CollectionUtil.isNotEmpty(aiBotPlugins)) { for (AiBotPlugins aiBotPlugin : aiBotPlugins) { Function function = aiBotPlugin.getAiPlugins().toFunction(); humanMessage.addFunction(function); } } } }