package tech.aiflowy.ai.config; import com.agentsflex.core.message.AiMessage; import com.agentsflex.core.prompt.HistoriesPrompt; import com.alibaba.fastjson.JSON; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.gson.*; import com.mybatisflex.core.query.QueryWrapper; import okhttp3.*; import okio.BufferedSource; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import tech.aiflowy.ai.entity.AiBotMessage; import tech.aiflowy.ai.service.AiBotMessageService; import tech.aiflowy.common.ai.MySseEmitter; import java.io.File; import java.io.IOException; import java.math.BigInteger; import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; public class DifyStreamClient { private final OkHttpClient client; private final String apiUrl; private final String apiKey; private final Gson gson; private String prompt; private AiBotMessageService aiBotMessageService; boolean blean = false; public DifyStreamClient(String apiUrl, String apiKey, AiBotMessageService aiBotMessageService) { this.apiUrl = apiUrl; this.apiKey = apiKey; this.gson = new GsonBuilder().setPrettyPrinting().create(); this.client = new OkHttpClient.Builder().connectTimeout(60, TimeUnit.SECONDS) // 连接超时 .readTimeout(120, TimeUnit.SECONDS).build(); this.aiBotMessageService = aiBotMessageService; } // 在DifyStreamClient类中添加以下方法 public CompletableFuture runWorkflow(Map inputs, String message, String userId, MySseEmitter emitter, String sessionId, BigInteger botId) { // 构建请求JSON JsonObject requestBody = new JsonObject(); // 添加inputs参数 JsonObject inputsJson = new JsonObject(); if (inputs != null) { for (Map.Entry entry : inputs.entrySet()) { if (entry.getValue() instanceof String) { inputsJson.addProperty(entry.getKey(), (String) entry.getValue()); } else if (entry.getValue() instanceof Number) { inputsJson.addProperty(entry.getKey(), (Number) entry.getValue()); } else if (entry.getValue() instanceof Boolean) { inputsJson.addProperty(entry.getKey(), (Boolean) entry.getValue()); } else { // 对于复杂对象,转换为JSON字符串 inputsJson.add(entry.getKey(), gson.toJsonTree(entry.getValue())); } } } requestBody.add("inputs", inputsJson); // 设置响应模式和用户ID requestBody.addProperty("response_mode", "streaming"); requestBody.addProperty("user", userId); // System.out.println(requestBody+"=============================================================================================="); // 创建请求 RequestBody body = RequestBody.create( gson.toJson(requestBody), MediaType.parse("application/json; charset=utf-8") ); Request request = new Request.Builder() .url(apiUrl) .post(body) .header("Authorization", apiKey) .header("Content-Type", "application/json") .build(); CompletableFuture future = new CompletableFuture<>(); // 设置SseEmitter生命周期回调 emitter.onTimeout(() -> { System.out.println("SSE连接超时"); emitter.complete(); future.complete(null); }); emitter.onCompletion(() -> { System.out.println("SSE连接已完成"); future.complete(null); }); emitter.onError(e -> { System.out.println("SSE连接错误: " + e.getMessage()); emitter.completeWithError(e); future.completeExceptionally(e); }); // 发送异步请求 client.newCall(request).enqueue(new Callback() { @Override public void onFailure(Call call, IOException e) { emitter.completeWithError(e); future.completeExceptionally(e); } @Override public void onResponse(Call call, Response response) { StringBuffer sb = new StringBuffer(); try (ResponseBody responseBody = response.body()) { if (!response.isSuccessful()) { emitter.completeWithError(new IOException("API错误: " + response.code())); return; } // 使用BufferedSource逐行读取响应内容 BufferedSource source = responseBody.source(); String line; while ((line = source.readUtf8Line()) != null) { if (line.startsWith("data: ")) { String data = line.substring(6).trim(); // 忽略空数据或结束标记 if (data.isEmpty() || data.equals("[DONE]")) { continue; } try { // 这里需要根据实际API返回结构调整 // 假设API返回的格式是{ "output": "消息内容" } JsonObject jsonObject = gson.fromJson(data, JsonObject.class); String title = null; // System.out.println(jsonObject); if (jsonObject != null && jsonObject.has("data")) { JsonElement dataElement = jsonObject.get("data"); if (dataElement != null && !dataElement.isJsonNull()) { JsonObject dataObject = dataElement.getAsJsonObject(); if (dataObject != null && dataObject.has("node_type")) { continue; } if (dataObject != null && dataObject.has("title")) { JsonElement titleElement = dataObject.get("title"); if (titleElement != null && !titleElement.isJsonNull()) { title = titleElement.getAsString(); } } else if (dataObject != null && dataObject.has("text")) { JsonElement titleElement = dataObject.get("text"); if (titleElement != null && !titleElement.isJsonNull()) { title = titleElement.getAsString(); } }else if (dataObject != null && dataObject.has("outputs")) { JsonElement titleElement = dataObject.get("outputs"); if (titleElement != null && !titleElement.isJsonNull()) { AiBotMessage aiBotMessage = new AiBotMessage(); aiBotMessage.setBotId(botId); aiBotMessage.setSessionId(sessionId); aiBotMessage.setAccountId(new BigInteger(userId)); aiBotMessage.setRole("assistant"); aiBotMessage.setContent(sb.toString()); aiBotMessage.setCreated(new Date()); aiBotMessage.setIsExternalMsg(1); aiBotMessageService.save(aiBotMessage); // dataObject = titleElement.getAsJsonObject(); // if (dataObject != null && dataObject.has("text")) { // titleElement = dataObject.get("text"); // if (titleElement != null && !titleElement.isJsonNull()) { // title = titleElement.getAsString(); // } // }else if (dataObject != null && dataObject.has("data")) { // titleElement = dataObject.get("data"); // if (titleElement != null && !titleElement.isJsonNull()) { // title = titleElement.getAsString(); // } // } } } } } // 创建消息对象并发送给前端 AiMessage aiMessage = new AiMessage(); aiMessage.setContent(title); System.out.println(gson.fromJson(data, JsonObject.class)); sb.append(aiMessage.getContent()); // 将消息发送给前端 if (aiMessage.getContent() != null){ emitter.send(JSON.toJSONString(aiMessage)); } } catch (Exception e) { // 记录解析错误但继续处理后续数据 System.err.println("解析响应数据时出错: " + e.getMessage()); emitter.completeWithError(e); } } } // 所有数据处理完毕,发送完成信号 emitter.send(SseEmitter.event().name("complete")); emitter.complete(); } catch (IOException e) { emitter.completeWithError(e); emitter.complete(); } catch (Exception e) { // 处理其他异常 emitter.completeWithError(e); } } }); return future; } public String fileUpload(String userId, MultipartFile file) { // 用户标识,要和发送消息接口的 user 保持一致 String user = userId; // 判断文件是否为空 if (file.isEmpty()) { System.out.println("上传文件为空"); return "上传文件为空"; } OkHttpClient client = new OkHttpClient(); // 构建 multipart/form-data 请求体 MultipartBody.Builder requestBodyBuilder = null; try { requestBodyBuilder = new MultipartBody.Builder() .setType(MultipartBody.FORM) // 添加上传文件,这里直接用 MultipartFile 的字节数组构建请求体 .addFormDataPart("file", file.getOriginalFilename(), RequestBody.create(MediaType.parse("application/octet-stream"), file.getBytes())) .addFormDataPart("user", user); } catch (IOException e) { throw new RuntimeException(e); } RequestBody requestBody = requestBodyBuilder.build(); Request request = new Request.Builder() .url(apiUrl) // 替换为实际的接口地址常量 .post(requestBody) .header("Authorization", apiKey) // 替换为实际的授权密钥常量 .build(); try (Response response = client.newCall(request).execute()) { if (!response.isSuccessful()) { System.out.println("请求失败,状态码:" + response.code()); return user; } String responseBody = response.body().string(); System.out.println("上传结果:" + responseBody); return responseBody; } catch (IOException e) { e.printStackTrace(); return "文件上传失败:" + e.getMessage(); } } // 流式聊天方法 - 直接集成SseEmitter public CompletableFuture chatStream(String message, String userId, MySseEmitter emitter, String sessionId, BigInteger botId) { prompt = message; QueryWrapper qw = new QueryWrapper(); qw.eq(AiBotMessage::getSessionId, sessionId) .orderBy(AiBotMessage::getId,false) .limit(6); List history = aiBotMessageService.list(qw); // System.out.println("======================history==================\n"); // for (AiBotMessage aiBotMessage : history) { // System.out.println(aiBotMessage.toMessage()); // } // System.out.println("\n======================history=================="); // 构建请求JSON JsonObject requestBody = new JsonObject(); requestBody.add("inputs", new JsonObject()); requestBody.addProperty("query", message); requestBody.addProperty("response_mode", "streaming"); requestBody.addProperty("conversation_id", ""); requestBody.addProperty("user", userId); // requestBody.add("files", new JsonArray()); // 添加历史对话信息 JsonArray historyArray = new JsonArray(); // for (AiBotMessage msg : history) { // historyArray.add(String.valueOf(msg)); // } // requestBody.add("history", historyArray); RequestBody body = RequestBody.create( gson.toJson(requestBody), MediaType.parse("application/json; charset=utf-8") ); Request request = new Request.Builder() .url(apiUrl) .post(body) .header("Authorization", apiKey) .header("Content-Type", "application/json") .build(); CompletableFuture future = new CompletableFuture<>(); // 设置SseEmitter生命周期回调 emitter.onTimeout(() -> { System.out.println("SSE连接超时"); emitter.complete(); future.complete(null); }); emitter.onCompletion(() -> { System.out.println("SSE连接已完成"); future.complete(null); }); emitter.onError(e -> { System.out.println("SSE连接错误: " + e.getMessage()); emitter.completeWithError(e); future.completeExceptionally(e); }); // 发送异步请求 client.newCall(request).enqueue(new Callback() { @Override public void onFailure(Call call, IOException e) { emitter.completeWithError(e); future.completeExceptionally(e); } @Override public void onResponse(Call call, Response response) { int a = 1; AiMessage aiMessage = new AiMessage(); com.agentsflex.core.llm.response.AiMessageResponse aiMessageResponse = new com.agentsflex.core.llm.response.AiMessageResponse(new HistoriesPrompt(), response.message(), aiMessage); try (ResponseBody responseBody = response.body()) { if (!response.isSuccessful()) { emitter.completeWithError(new IOException("API错误: " + response.code())); return; } // 使用BufferedSource逐行读取响应内容 BufferedSource source = responseBody.source(); String line; StringBuffer sb = new StringBuffer(); // 标记是否为第一条有效数据(用于处理某些API的特殊格式) boolean isFirstData = true; while ((line = source.readUtf8Line()) != null) { if (line.startsWith("data: ")) { String data = line.substring(6).trim(); // 忽略空数据或结束标记 if (data.isEmpty() || data.equals("[DONE]")) { continue; } try { // 根据实际API响应结构提取消息内容 // 这里需要根据具体API调整路径 JsonObject messageObj = gson.fromJson(data, JsonObject.class); // System.out.println(messageObj); if(!messageObj.has("answer")){ try { JsonArray asJsonArray = messageObj.getAsJsonObject("metadata").getAsJsonArray("retriever_resources"); if (asJsonArray.size() > 0) { aiMessage.setFullContent("-----------------------"); sb.append("\n"+aiMessage.getFullContent()); emitter.send(JSON.toJSONString(aiMessage)); for (int i = 0; i < asJsonArray.size(); i++) { aiMessage.setFullContent(asJsonArray.get(i).getAsJsonObject().get("document_name").getAsString()); aiMessage.setContent(null); // aiMessageResponse.setMessage(aiMessage); sb.append("\n"+aiMessage.getFullContent()); emitter.send(JSON.toJSONString(aiMessage)); } } } catch (Exception e) { System.out.println("meizuo"); } AiBotMessage aiBotMessage = new AiBotMessage(); aiBotMessage.setBotId(botId); aiBotMessage.setSessionId(sessionId); aiBotMessage.setAccountId(new BigInteger(userId)); aiBotMessage.setRole("assistant"); String content = aiBotMessage.getContent(); aiBotMessage.setContent(sb.toString()); aiBotMessage.setCreated(new Date()); aiBotMessage.setIsExternalMsg(1); if(a == 1){ a = 0; aiBotMessageService.save(aiBotMessage); }else{ QueryWrapper qw = new QueryWrapper(); qw.eq("content", content); aiBotMessageService.remove(qw); aiBotMessageService.save(aiBotMessage); } // System.out.println("end"); }else{ String context = messageObj.get("answer").getAsString(); // System.out.println(context); if (context != null) { // // 只移除HTML标签,保留Markdown特殊字符 // context = context.replaceAll("(?i)<[^>]*>", ""); context = context.replaceFirst("Thinking...", ""); if(context.startsWith("]*>", ""); } } aiMessage.setContent(context); aiMessageResponse.setMessage(aiMessage); if (!messageObj.get("answer").getAsString().isEmpty()) { // if(!blean && aiMessage.getContent().startsWith("")){ // blean = true; // aiMessage.setContent(aiMessage.getContent().replaceAll("(?i)<[^>]*>", "\n\n")); // } // if(blean){ sb.append(aiMessage.getContent()); // System.out.println(aiMessage); // 发送消息片段给前端 emitter.send(JSON.toJSONString(aiMessage)); // } } } // 重置第一条数据标记 isFirstData = false; } catch (Exception e) { // 记录解析错误但继续处理后续数据 emitter.completeWithError(e); } } } // 所有数据处理完毕,发送完成信号 emitter.send(SseEmitter.event().name("complete")); emitter.complete(); } catch (IOException e) { emitter.completeWithError(e); emitter.complete(); } catch (Exception e) { // 处理其他异常 emitter.completeWithError(e); } } }); return future; } // 数据模型类 public interface StreamResponseListener { void onMessage(ChatContext context, AiMessageResponse response); void onStop(ChatContext context); void onFailure(ChatContext context, Throwable throwable); } public static class ChatContext { // 上下文信息 } public static class AiMessageResponse { private Message message; private JsonArray functionCallers; public Message getMessage() { return message; } public void setMessage(Message message) { this.message = message; } public JsonArray getFunctionCallers() { return functionCallers; } public void setFunctionCallers(JsonArray functionCallers) { this.functionCallers = functionCallers; } } public static class Message { private String content; public String getContent() { return content; } public void setContent(String content) { this.content = content; } } }