package tech.aiflowy.ai.config; import com.agentsflex.core.llm.response.AiMessageResponse; import com.agentsflex.core.message.AiMessage; import com.agentsflex.core.message.Message; import com.agentsflex.core.prompt.HistoriesPrompt; import com.agentsflex.core.prompt.Prompt; import com.alibaba.fastjson.JSON; import com.google.gson.Gson; import com.google.gson.GsonBuilder; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.mybatisflex.core.query.QueryWrapper; import okhttp3.*; import okio.BufferedSource; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import tech.aiflowy.ai.entity.AiBotConversationMessage; import tech.aiflowy.ai.entity.AiBotMessage; import tech.aiflowy.ai.service.AiBotConversationMessageService; import tech.aiflowy.ai.service.AiBotMessageService; import tech.aiflowy.ai.service.impl.AiBotMessageServiceImpl; import tech.aiflowy.common.ai.MySseEmitter; import tech.aiflowy.common.util.StringUtil; import tech.aiflowy.common.web.controller.BaseCurdController; import javax.annotation.Resource; import java.io.IOException; import java.math.BigInteger; import java.util.*; import java.util.concurrent.CompletableFuture; public class DifyStreamClient { private final OkHttpClient client; private final String apiUrl; private final String apiKey; private final Gson gson; private String prompt; private AiBotMessageService aiBotMessageService; public DifyStreamClient(String apiUrl, String apiKey, AiBotMessageService aiBotMessageService) { this.apiUrl = apiUrl; this.apiKey = apiKey; this.gson = new GsonBuilder().setPrettyPrinting().create(); this.client = new OkHttpClient.Builder().build(); this.aiBotMessageService = aiBotMessageService; } // 流式聊天方法 - 直接集成SseEmitter public CompletableFuture chatStream(String message, String userId, MySseEmitter emitter, String sessionId, BigInteger botId) { prompt = message; QueryWrapper qw = new QueryWrapper(); qw.eq(AiBotMessage::getSessionId, sessionId) .orderBy(AiBotMessage::getId,false) .limit(6); List history = aiBotMessageService.list(qw); // System.out.println("======================history==================\n"); // for (AiBotMessage aiBotMessage : history) { // System.out.println(aiBotMessage.toMessage()); // } // System.out.println("\n======================history=================="); // 构建请求JSON JsonObject requestBody = new JsonObject(); requestBody.add("inputs", new JsonObject()); requestBody.addProperty("query", message); requestBody.addProperty("response_mode", "streaming"); requestBody.addProperty("conversation_id", ""); requestBody.addProperty("user", userId); requestBody.add("files", new JsonArray()); // 添加历史对话信息 JsonArray historyArray = new JsonArray(); for (AiBotMessage msg : history) { historyArray.add(String.valueOf(msg)); } requestBody.add("history", historyArray); RequestBody body = RequestBody.create( gson.toJson(requestBody), MediaType.parse("application/json; charset=utf-8") ); Request request = new Request.Builder() .url(apiUrl) .post(body) .header("Authorization", apiKey) .header("Content-Type", "application/json") .build(); CompletableFuture future = new CompletableFuture<>(); // 设置SseEmitter生命周期回调 emitter.onTimeout(() -> { System.out.println("SSE连接超时"); emitter.complete(); future.complete(null); }); emitter.onCompletion(() -> { System.out.println("SSE连接已完成"); future.complete(null); }); emitter.onError(e -> { System.out.println("SSE连接错误: " + e.getMessage()); emitter.completeWithError(e); future.completeExceptionally(e); }); // 发送异步请求 client.newCall(request).enqueue(new Callback() { @Override public void onFailure(Call call, IOException e) { emitter.completeWithError(e); future.completeExceptionally(e); } @Override public void onResponse(Call call, Response response) { AiMessage aiMessage = new AiMessage(); com.agentsflex.core.llm.response.AiMessageResponse aiMessageResponse = new com.agentsflex.core.llm.response.AiMessageResponse(new HistoriesPrompt(), response.message(), aiMessage); try (ResponseBody responseBody = response.body()) { if (!response.isSuccessful()) { emitter.completeWithError(new IOException("API错误: " + response.code())); return; } // 使用BufferedSource逐行读取响应内容 BufferedSource source = responseBody.source(); String line; StringBuffer sb = new StringBuffer(); // 标记是否为第一条有效数据(用于处理某些API的特殊格式) boolean isFirstData = true; while ((line = source.readUtf8Line()) != null) { if (line.startsWith("data: ")) { String data = line.substring(6).trim(); // 忽略空数据或结束标记 if (data.isEmpty() || data.equals("[DONE]")) { continue; } try { // 根据实际API响应结构提取消息内容 // 这里需要根据具体API调整路径 JsonObject messageObj = gson.fromJson(data, JsonObject.class); // System.out.println(messageObj); if(messageObj.get("event").getAsString().equals("message_end")){ AiBotMessage aiBotMessage = new AiBotMessage(); aiBotMessage.setBotId(botId); aiBotMessage.setSessionId(sessionId); aiBotMessage.setAccountId(new BigInteger(userId)); aiBotMessage.setRole("assistant"); aiBotMessage.setContent(sb.toString()); aiBotMessage.setCreated(new Date()); aiBotMessage.setIsExternalMsg(1); aiBotMessageService.save(aiBotMessage); // System.out.println("end"); }else{ String context = messageObj.get("answer").getAsString(); if (context != null) { // // 只移除HTML标签,保留Markdown特殊字符 // context = context.replaceAll("(?i)<[^>]*>", ""); context = context.replaceFirst("Thinking...", ""); } sb.append(context); aiMessage.setContent(context); aiMessageResponse.setMessage(aiMessage); if (StringUtil.hasText(messageObj.get("answer").getAsString())) { // System.out.println(aiMessage); if(aiMessage.getContent().startsWith("")){ aiMessage.setContent(aiMessage.getContent().replaceAll("(?i)<[^>]*>", ""+"\n\n")); } // 发送消息片段给前端 emitter.send(JSON.toJSONString(aiMessage)); } } // 重置第一条数据标记 isFirstData = false; } catch (Exception e) { // 记录解析错误但继续处理后续数据 emitter.completeWithError(e); } } } // 所有数据处理完毕,发送完成信号 emitter.send(SseEmitter.event().name("complete")); emitter.complete(); } catch (IOException e) { emitter.completeWithError(e); emitter.complete(); } catch (Exception e) { // 处理其他异常 emitter.completeWithError(e); } } // @Override public void onResponses(Call call, Response response) { // try (ResponseBody responseBody = response.body()) { // if (!response.isSuccessful() || responseBody == null) { // IOException e = new IOException("Unexpected code " + response); // emitter.completeWithError(e); // future.completeExceptionally(e); // return; // } //// String rawResponse = responseBody.string(); // 打印原始响应 //// System.out.println("Dify原始响应: " + rawResponse); // // BufferedSource source = responseBody.source(); // String line; // // // 处理流式响应 // while ((line = source.readUtf8Line()) != null) { // if (line.startsWith("data: ")) { // String data = line.substring(6); // 移除"data: "前缀 // // // 跳过空数据 // if (data.trim().equals("[DONE]")) { // emitter.complete(); // future.complete(null); // break; // } // // try { // // 解析JSON响应 // JsonObject jsonResponse = gson.fromJson(data, JsonObject.class); // // 发送消息内容到客户端 // if (jsonResponse.has("message") && jsonResponse.get("message").isJsonObject()) { // String content = jsonResponse.getAsJsonObject("message").get("content").getAsString(); // emitter.send(SseEmitter.event() // .data(content) // .name("message") // ); // } // // // 处理函数调用 // if (jsonResponse.has("function_call") && jsonResponse.get("function_call").isJsonObject()) { // emitter.send(SseEmitter.event() // .data(jsonResponse.get("function_call")) // .name("function_call") // ); // } // // } catch (Exception e) { // emitter.completeWithError(e); // future.completeExceptionally(e); // break; // } // } // } try (ResponseBody responseBody = response.body()) { if (!response.isSuccessful()) { emitter.completeWithError(new IOException("API错误")); return; } StringBuilder fullAnswer = new StringBuilder(); BufferedSource source = responseBody.source(); String line; while ((line = source.readUtf8Line()) != null) { if (line.startsWith("data: ")) { String data = line.substring(6).trim(); if (data.equals("[DONE]")) { // 发送完整回答并结束流 emitter.send(fullAnswer.toString()); emitter.complete(); break; } try { JsonObject json = gson.fromJson(data, JsonObject.class); System.out.println(json); String answerFragment = json.get("answer").getAsString(); // 过滤无效字符(如️⃣、\n) String cleanedFragment = answerFragment.replaceAll("[^\u4e00-\u9fa50-9a-zA-Z\\s\\p{Punct}]", ""); fullAnswer.append(cleanedFragment); // 拼接碎片 // 可选:每拼接一部分发送给前端(需前端支持增量渲染) // emitter.send(SseEmitter.event().data(cleanedFragment)); } catch (Exception e) { emitter.sendAndComplete("解析错误: " + e.getMessage()); } } } } catch (Exception e) { emitter.completeWithError(e); future.completeExceptionally(e); } } }); return future; } // 数据模型类 public interface StreamResponseListener { void onMessage(ChatContext context, AiMessageResponse response); void onStop(ChatContext context); void onFailure(ChatContext context, Throwable throwable); } public static class ChatContext { // 上下文信息 } public static class AiMessageResponse { private Message message; private JsonArray functionCallers; public Message getMessage() { return message; } public void setMessage(Message message) { this.message = message; } public JsonArray getFunctionCallers() { return functionCallers; } public void setFunctionCallers(JsonArray functionCallers) { this.functionCallers = functionCallers; } } public static class Message { private String content; public String getContent() { return content; } public void setContent(String content) { this.content = content; } } }