| | |
| | | package tech.aiflowy.ai.config; |
| | | |
| | | import com.agentsflex.core.llm.response.AiMessageResponse; |
| | | import com.agentsflex.core.message.AiMessage; |
| | | import com.agentsflex.core.message.Message; |
| | | import com.agentsflex.core.prompt.HistoriesPrompt; |
| | | import com.agentsflex.core.prompt.Prompt; |
| | | import com.alibaba.fastjson.JSON; |
| | | import com.google.gson.Gson; |
| | | import com.google.gson.GsonBuilder; |
| | | import com.google.gson.JsonArray; |
| | | import com.google.gson.JsonObject; |
| | | import com.google.gson.*; |
| | | import com.mybatisflex.core.query.QueryWrapper; |
| | | import okhttp3.*; |
| | | import okio.BufferedSource; |
| | | import org.springframework.beans.factory.annotation.Autowired; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| | | import tech.aiflowy.ai.entity.AiBotConversationMessage; |
| | | import tech.aiflowy.ai.entity.AiBotMessage; |
| | | import tech.aiflowy.ai.service.AiBotConversationMessageService; |
| | | import tech.aiflowy.ai.service.AiBotMessageService; |
| | | import tech.aiflowy.ai.service.impl.AiBotMessageServiceImpl; |
| | | import tech.aiflowy.common.ai.MySseEmitter; |
| | | import tech.aiflowy.common.util.StringUtil; |
| | | import tech.aiflowy.common.web.controller.BaseCurdController; |
| | | |
| | | import javax.annotation.Resource; |
| | | import java.io.File; |
| | | import java.io.IOException; |
| | | import java.math.BigInteger; |
| | | import java.util.*; |
| | |
private final Gson gson;
// NOTE(review): not read by any code visible in this chunk — confirm it is still needed.
private String prompt;
// Only assigned in the visible code; presumably used by methods outside this chunk.
private AiBotMessageService aiBotMessageService;
// NOTE(review): instance-level flag that the streaming chat handler sets to true the first time an
// answer chunk starts with "</details>". Nothing visible resets it, so once set it stays set for
// every later request served by this client instance, and it is not thread-safe. It probably
// belongs in per-request state — confirm before relying on it.
boolean blean = false;

/**
 * Creates a client bound to a single Dify API endpoint.
 *
 * @param apiUrl              endpoint URL that requests are POSTed to
 * @param apiKey              value sent verbatim as the {@code Authorization} header
 * @param aiBotMessageService message service kept on this instance for use by other methods
 */
public DifyStreamClient(String apiUrl, String apiKey, AiBotMessageService aiBotMessageService) {
    this.apiUrl = apiUrl;
    // Stored so that runWorkflow/fileUpload can send it as the Authorization header;
    // leaving it unset makes OkHttp reject the null header value.
    this.apiKey = apiKey;
    this.gson = new GsonBuilder().setPrettyPrinting().create();
    this.client = new OkHttpClient.Builder().build();
    this.aiBotMessageService = aiBotMessageService;
}
| | | |
| | | // 在DifyStreamClient类中添加以下方法 |
| | | public CompletableFuture<Void> runWorkflow(Map<String, Object> inputs, String message, String userId, MySseEmitter emitter, String sessionId, BigInteger botId) { |
| | | // 构建请求JSON |
| | | JsonObject requestBody = new JsonObject(); |
| | | // 添加inputs参数 |
| | | JsonObject inputsJson = new JsonObject(); |
| | | if (inputs != null) { |
| | | for (Map.Entry<String, Object> entry : inputs.entrySet()) { |
| | | if (entry.getValue() instanceof String) { |
| | | inputsJson.addProperty(entry.getKey(), (String) entry.getValue()); |
| | | } else if (entry.getValue() instanceof Number) { |
| | | inputsJson.addProperty(entry.getKey(), (Number) entry.getValue()); |
| | | } else if (entry.getValue() instanceof Boolean) { |
| | | inputsJson.addProperty(entry.getKey(), (Boolean) entry.getValue()); |
| | | } else { |
| | | // 对于复杂对象,转换为JSON字符串 |
| | | inputsJson.add(entry.getKey(), gson.toJsonTree(entry.getValue())); |
| | | } |
| | | } |
| | | } |
| | | requestBody.add("inputs", inputsJson); |
| | | |
| | | // 设置响应模式和用户ID |
| | | requestBody.addProperty("response_mode", "streaming"); |
| | | requestBody.addProperty("user", userId); |
| | | |
| | | // 创建请求 |
| | | RequestBody body = RequestBody.create( |
| | | gson.toJson(requestBody), |
| | | MediaType.parse("application/json; charset=utf-8") |
| | | ); |
| | | |
| | | Request request = new Request.Builder() |
| | | .url(apiUrl) |
| | | .post(body) |
| | | .header("Authorization", apiKey) |
| | | .header("Content-Type", "application/json") |
| | | .build(); |
| | | |
| | | CompletableFuture<Void> future = new CompletableFuture<>(); |
| | | |
| | | // 设置SseEmitter生命周期回调 |
| | | emitter.onTimeout(() -> { |
| | | System.out.println("SSE连接超时"); |
| | | emitter.complete(); |
| | | future.complete(null); |
| | | }); |
| | | |
| | | emitter.onCompletion(() -> { |
| | | System.out.println("SSE连接已完成"); |
| | | future.complete(null); |
| | | }); |
| | | |
| | | emitter.onError(e -> { |
| | | System.out.println("SSE连接错误: " + e.getMessage()); |
| | | emitter.completeWithError(e); |
| | | future.completeExceptionally(e); |
| | | }); |
| | | |
| | | // 发送异步请求 |
| | | client.newCall(request).enqueue(new Callback() { |
| | | @Override |
| | | public void onFailure(Call call, IOException e) { |
| | | emitter.completeWithError(e); |
| | | future.completeExceptionally(e); |
| | | } |
| | | |
| | | @Override |
| | | public void onResponse(Call call, Response response) { |
| | | try (ResponseBody responseBody = response.body()) { |
| | | if (!response.isSuccessful()) { |
| | | emitter.completeWithError(new IOException("API错误: " + response.code())); |
| | | return; |
| | | } |
| | | |
| | | // 使用BufferedSource逐行读取响应内容 |
| | | BufferedSource source = responseBody.source(); |
| | | String line; |
| | | |
| | | while ((line = source.readUtf8Line()) != null) { |
| | | if (line.startsWith("data: ")) { |
| | | String data = line.substring(6).trim(); |
| | | |
| | | // 忽略空数据或结束标记 |
| | | if (data.isEmpty() || data.equals("[DONE]")) { |
| | | continue; |
| | | } |
| | | |
| | | try { |
| | | // 这里需要根据实际API返回结构调整 |
| | | // 假设API返回的格式是{ "output": "消息内容" } |
| | | JsonObject jsonObject = gson.fromJson(data, JsonObject.class); |
| | | String title = null; |
| | | if (jsonObject != null && jsonObject.has("data")) { |
| | | JsonElement dataElement = jsonObject.get("data"); |
| | | if (dataElement != null && !dataElement.isJsonNull()) { |
| | | JsonObject dataObject = dataElement.getAsJsonObject(); |
| | | if (dataObject != null && dataObject.has("title")) { |
| | | JsonElement titleElement = dataObject.get("title"); |
| | | if (titleElement != null && !titleElement.isJsonNull()) { |
| | | title = titleElement.getAsString(); |
| | | } |
| | | } |
| | | if (dataObject != null && dataObject.has("text")) { |
| | | JsonElement titleElement = dataObject.get("text"); |
| | | if (titleElement != null && !titleElement.isJsonNull()) { |
| | | title = titleElement.getAsString(); |
| | | } |
| | | } |
| | | } |
| | | } |
| | | |
| | | // 创建消息对象并发送给前端 |
| | | AiMessage aiMessage = new AiMessage(); |
| | | aiMessage.setContent(title); |
| | | System.out.println(gson.fromJson(data, JsonObject.class)); |
| | | // 将消息发送给前端 |
| | | emitter.send(JSON.toJSONString(aiMessage)); |
| | | |
| | | } catch (Exception e) { |
| | | // 记录解析错误但继续处理后续数据 |
| | | System.err.println("解析响应数据时出错: " + e.getMessage()); |
| | | emitter.completeWithError(e); |
| | | } |
| | | } |
| | | } |
| | | |
| | | // 所有数据处理完毕,发送完成信号 |
| | | emitter.send(SseEmitter.event().name("complete")); |
| | | emitter.complete(); |
| | | |
| | | } catch (IOException e) { |
| | | emitter.completeWithError(e); |
| | | emitter.complete(); |
| | | } catch (Exception e) { |
| | | // 处理其他异常 |
| | | emitter.completeWithError(e); |
| | | } |
| | | } |
| | | }); |
| | | |
| | | return future; |
| | | } |
| | | |
| | | public String fileUpload(String userId, String filePath){ |
| | | // 要上传的文件路径,替换为实际的文件路径 |
| | | // String filePath = "C:\\Users\\admin\\Desktop\\国务院政策文件库.xlsx"; |
| | | // 用户标识,替换为实际的用户标识,要和发送消息接口的 user 保持一致 |
| | | String user = userId; |
| | | |
| | | File file = new File(filePath); |
| | | if (!file.exists()) { |
| | | System.out.println("文件不存在:" + filePath); |
| | | return user; |
| | | } |
| | | |
| | | OkHttpClient client = new OkHttpClient(); |
| | | |
| | | // 构建 multipart/form-data 请求体 |
| | | RequestBody requestBody = new MultipartBody.Builder() |
| | | .setType(MultipartBody.FORM) |
| | | .addFormDataPart("file", file.getName(), |
| | | RequestBody.create(MediaType.parse("application/octet-stream"), file)) |
| | | .addFormDataPart("user", user) |
| | | .build(); |
| | | |
| | | Request request = new Request.Builder() |
| | | .url(apiUrl) |
| | | .post(requestBody) |
| | | .header("Authorization", apiKey) |
| | | .build(); |
| | | |
| | | try (Response response = client.newCall(request).execute()) { |
| | | if (!response.isSuccessful()) { |
| | | System.out.println("请求失败,状态码:" + response.code()); |
| | | return user; |
| | | } |
| | | String responseBody = response.body().string(); |
| | | System.out.println("上传结果:" + responseBody); |
| | | return responseBody; |
| | | } catch (IOException e) { |
| | | e.printStackTrace(); |
| | | return e.getMessage(); |
| | | } |
| | | } |
| | | |
| | | // 流式聊天方法 - 直接集成SseEmitter |
| | |
| | | emitter.completeWithError(e); |
| | | future.completeExceptionally(e); |
| | | }); |
| | | |
| | | // 发送异步请求 |
| | | client.newCall(request).enqueue(new Callback() { |
| | | @Override |
| | |
| | | // System.out.println("end"); |
| | | }else{ |
| | | String context = messageObj.get("answer").getAsString(); |
| | | // System.out.println(context); |
| | | if (context != null) { |
| | | // // 只移除HTML标签,保留Markdown特殊字符 |
| | | // context = context.replaceAll("(?i)<[^>]*>", ""); |
| | | context = context.replaceFirst("Thinking...", ""); |
| | | if(context.startsWith("<details")){ |
| | | context = context.replaceAll("(?i)<[^>]*>", ""); |
| | | } |
| | | } |
| | | sb.append(context); |
| | | aiMessage.setContent(context); |
| | | aiMessageResponse.setMessage(aiMessage); |
| | | if (StringUtil.hasText(messageObj.get("answer").getAsString())) { |
| | | // System.out.println(aiMessage); |
| | | if(aiMessage.getContent().startsWith("</details>")){ |
| | | aiMessage.setContent(aiMessage.getContent().replaceAll("(?i)<[^>]*>", "</details>"+"\n\n")); |
| | | if (!messageObj.get("answer").getAsString().isEmpty()) { |
| | | if(!blean && aiMessage.getContent().startsWith("</details>")){ |
| | | blean = true; |
| | | aiMessage.setContent(aiMessage.getContent().replaceAll("(?i)<[^>]*>", "\n\n")); |
| | | } |
| | | // 发送消息片段给前端 |
| | | emitter.send(JSON.toJSONString(aiMessage)); |
| | | if(blean){ |
| | | sb.append(aiMessage.getContent()); |
| | | // System.out.println(aiMessage); |
| | | // 发送消息片段给前端 |
| | | emitter.send(JSON.toJSONString(aiMessage)); |
| | | } |
| | | } |
| | | } |
| | | |
| | |
| | | } catch (Exception e) { |
| | | // 处理其他异常 |
| | | emitter.completeWithError(e); |
| | | } |
| | | } |
| | | |
| | | // @Override |
| | | public void onResponses(Call call, Response response) { |
| | | // try (ResponseBody responseBody = response.body()) { |
| | | // if (!response.isSuccessful() || responseBody == null) { |
| | | // IOException e = new IOException("Unexpected code " + response); |
| | | // emitter.completeWithError(e); |
| | | // future.completeExceptionally(e); |
| | | // return; |
| | | // } |
| | | //// String rawResponse = responseBody.string(); // 打印原始响应 |
| | | //// System.out.println("Dify原始响应: " + rawResponse); |
| | | // |
| | | // BufferedSource source = responseBody.source(); |
| | | // String line; |
| | | // |
| | | // // 处理流式响应 |
| | | // while ((line = source.readUtf8Line()) != null) { |
| | | // if (line.startsWith("data: ")) { |
| | | // String data = line.substring(6); // 移除"data: "前缀 |
| | | // |
| | | // // 跳过空数据 |
| | | // if (data.trim().equals("[DONE]")) { |
| | | // emitter.complete(); |
| | | // future.complete(null); |
| | | // break; |
| | | // } |
| | | // |
| | | // try { |
| | | // // 解析JSON响应 |
| | | // JsonObject jsonResponse = gson.fromJson(data, JsonObject.class); |
| | | // // 发送消息内容到客户端 |
| | | // if (jsonResponse.has("message") && jsonResponse.get("message").isJsonObject()) { |
| | | // String content = jsonResponse.getAsJsonObject("message").get("content").getAsString(); |
| | | // emitter.send(SseEmitter.event() |
| | | // .data(content) |
| | | // .name("message") |
| | | // ); |
| | | // } |
| | | // |
| | | // // 处理函数调用 |
| | | // if (jsonResponse.has("function_call") && jsonResponse.get("function_call").isJsonObject()) { |
| | | // emitter.send(SseEmitter.event() |
| | | // .data(jsonResponse.get("function_call")) |
| | | // .name("function_call") |
| | | // ); |
| | | // } |
| | | // |
| | | // } catch (Exception e) { |
| | | // emitter.completeWithError(e); |
| | | // future.completeExceptionally(e); |
| | | // break; |
| | | // } |
| | | // } |
| | | // } |
| | | try (ResponseBody responseBody = response.body()) { |
| | | if (!response.isSuccessful()) { |
| | | emitter.completeWithError(new IOException("API错误")); |
| | | return; |
| | | } |
| | | StringBuilder fullAnswer = new StringBuilder(); |
| | | BufferedSource source = responseBody.source(); |
| | | String line; |
| | | while ((line = source.readUtf8Line()) != null) { |
| | | if (line.startsWith("data: ")) { |
| | | String data = line.substring(6).trim(); |
| | | |
| | | if (data.equals("[DONE]")) { |
| | | // 发送完整回答并结束流 |
| | | emitter.send(fullAnswer.toString()); |
| | | emitter.complete(); |
| | | break; |
| | | } |
| | | |
| | | try { |
| | | JsonObject json = gson.fromJson(data, JsonObject.class); |
| | | System.out.println(json); |
| | | String answerFragment = json.get("answer").getAsString(); |
| | | // 过滤无效字符(如️⃣、\n) |
| | | String cleanedFragment = answerFragment.replaceAll("[^\u4e00-\u9fa50-9a-zA-Z\\s\\p{Punct}]", ""); |
| | | fullAnswer.append(cleanedFragment); // 拼接碎片 |
| | | |
| | | // 可选:每拼接一部分发送给前端(需前端支持增量渲染) |
| | | // emitter.send(SseEmitter.event().data(cleanedFragment)); |
| | | |
| | | } catch (Exception e) { |
| | | emitter.sendAndComplete("解析错误: " + e.getMessage()); |
| | | } |
| | | } |
| | | } |
| | | } catch (Exception e) { |
| | | emitter.completeWithError(e); |
| | | future.completeExceptionally(e); |
| | | } |
| | | } |
| | | }); |