admin
2025-06-09 4eb46966002c6ca24cbb8cc8b519a05610e81649
aiflowy-modules/aiflowy-module-ai/src/main/java/tech/aiflowy/ai/config/DifyStreamClient.java
@@ -1,31 +1,20 @@
package tech.aiflowy.ai.config;
import com.agentsflex.core.llm.response.AiMessageResponse;
import com.agentsflex.core.message.AiMessage;
import com.agentsflex.core.message.Message;
import com.agentsflex.core.prompt.HistoriesPrompt;
import com.agentsflex.core.prompt.Prompt;
import com.alibaba.fastjson.JSON;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.gson.*;
import com.mybatisflex.core.query.QueryWrapper;
import okhttp3.*;
import okio.BufferedSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import tech.aiflowy.ai.entity.AiBotConversationMessage;
import tech.aiflowy.ai.entity.AiBotMessage;
import tech.aiflowy.ai.service.AiBotConversationMessageService;
import tech.aiflowy.ai.service.AiBotMessageService;
import tech.aiflowy.ai.service.impl.AiBotMessageServiceImpl;
import tech.aiflowy.common.ai.MySseEmitter;
import tech.aiflowy.common.util.StringUtil;
import tech.aiflowy.common.web.controller.BaseCurdController;
import javax.annotation.Resource;
import java.io.File;
import java.io.IOException;
import java.math.BigInteger;
import java.util.*;
@@ -37,6 +26,7 @@
    private final Gson gson;
    private String prompt;
    private AiBotMessageService aiBotMessageService;
    boolean blean = false;
    public DifyStreamClient(String apiUrl, String apiKey, AiBotMessageService aiBotMessageService) {
        this.apiUrl = apiUrl;
@@ -44,6 +34,200 @@
        this.gson = new GsonBuilder().setPrettyPrinting().create();
        this.client = new OkHttpClient.Builder().build();
        this.aiBotMessageService = aiBotMessageService;
    }
    // 在DifyStreamClient类中添加以下方法
    public CompletableFuture<Void> runWorkflow(Map<String, Object> inputs, String message, String userId, MySseEmitter emitter, String sessionId, BigInteger botId) {
        // 构建请求JSON
        JsonObject requestBody = new JsonObject();
        // 添加inputs参数
        JsonObject inputsJson = new JsonObject();
        if (inputs != null) {
            for (Map.Entry<String, Object> entry : inputs.entrySet()) {
                if (entry.getValue() instanceof String) {
                    inputsJson.addProperty(entry.getKey(), (String) entry.getValue());
                } else if (entry.getValue() instanceof Number) {
                    inputsJson.addProperty(entry.getKey(), (Number) entry.getValue());
                } else if (entry.getValue() instanceof Boolean) {
                    inputsJson.addProperty(entry.getKey(), (Boolean) entry.getValue());
                } else {
                    // 对于复杂对象,转换为JSON字符串
                    inputsJson.add(entry.getKey(), gson.toJsonTree(entry.getValue()));
                }
            }
        }
        requestBody.add("inputs", inputsJson);
        // 设置响应模式和用户ID
        requestBody.addProperty("response_mode", "streaming");
        requestBody.addProperty("user", userId);
        // 创建请求
        RequestBody body = RequestBody.create(
                gson.toJson(requestBody),
                MediaType.parse("application/json; charset=utf-8")
        );
        Request request = new Request.Builder()
                .url(apiUrl)
                .post(body)
                .header("Authorization", apiKey)
                .header("Content-Type", "application/json")
                .build();
        CompletableFuture<Void> future = new CompletableFuture<>();
        // 设置SseEmitter生命周期回调
        emitter.onTimeout(() -> {
            System.out.println("SSE连接超时");
            emitter.complete();
            future.complete(null);
        });
        emitter.onCompletion(() -> {
            System.out.println("SSE连接已完成");
            future.complete(null);
        });
        emitter.onError(e -> {
            System.out.println("SSE连接错误: " + e.getMessage());
            emitter.completeWithError(e);
            future.completeExceptionally(e);
        });
        // 发送异步请求
        client.newCall(request).enqueue(new Callback() {
            @Override
            public void onFailure(Call call, IOException e) {
                emitter.completeWithError(e);
                future.completeExceptionally(e);
            }
            @Override
            public void onResponse(Call call, Response response) {
                try (ResponseBody responseBody = response.body()) {
                    if (!response.isSuccessful()) {
                        emitter.completeWithError(new IOException("API错误: " + response.code()));
                        return;
                    }
                    // 使用BufferedSource逐行读取响应内容
                    BufferedSource source = responseBody.source();
                    String line;
                    while ((line = source.readUtf8Line()) != null) {
                        if (line.startsWith("data: ")) {
                            String data = line.substring(6).trim();
                            // 忽略空数据或结束标记
                            if (data.isEmpty() || data.equals("[DONE]")) {
                                continue;
                            }
                            try {
                                // 这里需要根据实际API返回结构调整
                                // 假设API返回的格式是{ "output": "消息内容" }
                                JsonObject jsonObject = gson.fromJson(data, JsonObject.class);
                                String title = null;
                                if (jsonObject != null && jsonObject.has("data")) {
                                    JsonElement dataElement = jsonObject.get("data");
                                    if (dataElement != null && !dataElement.isJsonNull()) {
                                        JsonObject dataObject = dataElement.getAsJsonObject();
                                        if (dataObject != null && dataObject.has("node_type")) {
                                            continue;
                                        }
                                        if (dataObject != null && dataObject.has("title")) {
                                            JsonElement titleElement = dataObject.get("title");
                                            if (titleElement != null && !titleElement.isJsonNull()) {
                                                title = titleElement.getAsString();
                                            }
                                        }
                                        if (dataObject != null && dataObject.has("text")) {
                                            JsonElement titleElement = dataObject.get("text");
                                            if (titleElement != null && !titleElement.isJsonNull()) {
                                                title = titleElement.getAsString();
                                            }
                                        }
                                    }
                                }
                                // 创建消息对象并发送给前端
                                AiMessage aiMessage = new AiMessage();
                                aiMessage.setContent(title);
                                System.out.println(gson.fromJson(data, JsonObject.class));
                                // 将消息发送给前端
                                emitter.send(JSON.toJSONString(aiMessage));
                            } catch (Exception e) {
                                // 记录解析错误但继续处理后续数据
                                System.err.println("解析响应数据时出错: " + e.getMessage());
                                emitter.completeWithError(e);
                            }
                        }
                    }
                    // 所有数据处理完毕,发送完成信号
                    emitter.send(SseEmitter.event().name("complete"));
                    emitter.complete();
                } catch (IOException e) {
                    emitter.completeWithError(e);
                    emitter.complete();
                } catch (Exception e) {
                    // 处理其他异常
                    emitter.completeWithError(e);
                }
            }
        });
        return future;
    }
    public String fileUpload(String userId, MultipartFile file) {
        // 用户标识,要和发送消息接口的 user 保持一致
        String user = userId;
        // 判断文件是否为空
        if (file.isEmpty()) {
            System.out.println("上传文件为空");
            return "上传文件为空";
        }
        OkHttpClient client = new OkHttpClient();
        // 构建 multipart/form-data 请求体
        MultipartBody.Builder requestBodyBuilder = null;
        try {
            requestBodyBuilder = new MultipartBody.Builder()
                    .setType(MultipartBody.FORM)
                    // 添加上传文件,这里直接用 MultipartFile 的字节数组构建请求体
                    .addFormDataPart("file", file.getOriginalFilename(),
                            RequestBody.create(MediaType.parse("application/octet-stream"), file.getBytes()))
                    .addFormDataPart("user", user);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        RequestBody requestBody = requestBodyBuilder.build();
        Request request = new Request.Builder()
                .url(apiUrl)  // 替换为实际的接口地址常量
                .post(requestBody)
                .header("Authorization", apiKey)  // 替换为实际的授权密钥常量
                .build();
        try (Response response = client.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                System.out.println("请求失败,状态码:" + response.code());
                return user;
            }
            String responseBody = response.body().string();
            System.out.println("上传结果:" + responseBody);
            return responseBody;
        } catch (IOException e) {
            e.printStackTrace();
            return "文件上传失败:" + e.getMessage();
        }
    }
    // 流式聊天方法 - 直接集成SseEmitter
@@ -107,7 +291,6 @@
            emitter.completeWithError(e);
            future.completeExceptionally(e);
        });
        // 发送异步请求
        client.newCall(request).enqueue(new Callback() {
            @Override
@@ -161,21 +344,28 @@
//                                    System.out.println("end");
                                }else{
                                    String context = messageObj.get("answer").getAsString();
//                                    System.out.println(context);
                                    if (context != null) {
//                                    // 只移除HTML标签,保留Markdown特殊字符
//                                        context = context.replaceAll("(?i)<[^>]*>", "");
                                        context = context.replaceFirst("Thinking...", "");
                                        if(context.startsWith("<details")){
                                            context = context.replaceAll("(?i)<[^>]*>", "");
                                        }
                                    }
                                    sb.append(context);
                                    aiMessage.setContent(context);
                                    aiMessageResponse.setMessage(aiMessage);
                                    if (StringUtil.hasText(messageObj.get("answer").getAsString())) {
//                                        System.out.println(aiMessage);
                                        if(aiMessage.getContent().startsWith("</details>")){
                                            aiMessage.setContent(aiMessage.getContent().replaceAll("(?i)<[^>]*>", "</details>"+"\n\n"));
                                    if (!messageObj.get("answer").getAsString().isEmpty()) {
                                        if(!blean && aiMessage.getContent().startsWith("</details>")){
                                            blean = true;
                                            aiMessage.setContent(aiMessage.getContent().replaceAll("(?i)<[^>]*>", "\n\n"));
                                        }
                                        // 发送消息片段给前端
                                        emitter.send(JSON.toJSONString(aiMessage));
                                        if(blean){
                                            sb.append(aiMessage.getContent());
//                                            System.out.println(aiMessage);
                                            // 发送消息片段给前端
                                            emitter.send(JSON.toJSONString(aiMessage));
                                        }
                                    }
                                }
@@ -199,101 +389,6 @@
                } catch (Exception e) {
                    // 处理其他异常
                    emitter.completeWithError(e);
                }
            }
//            @Override
            public void onResponses(Call call, Response response) {
//                try (ResponseBody responseBody = response.body()) {
//                    if (!response.isSuccessful() || responseBody == null) {
//                        IOException e = new IOException("Unexpected code " + response);
//                        emitter.completeWithError(e);
//                        future.completeExceptionally(e);
//                        return;
//                    }
////                    String rawResponse = responseBody.string(); // 打印原始响应
////                    System.out.println("Dify原始响应: " + rawResponse);
//
//                    BufferedSource source = responseBody.source();
//                    String line;
//
//                    // 处理流式响应
//                    while ((line = source.readUtf8Line()) != null) {
//                        if (line.startsWith("data: ")) {
//                            String data = line.substring(6); // 移除"data: "前缀
//
//                            // 跳过空数据
//                            if (data.trim().equals("[DONE]")) {
//                                emitter.complete();
//                                future.complete(null);
//                                break;
//                            }
//
//                            try {
//                                // 解析JSON响应
//                                JsonObject jsonResponse = gson.fromJson(data, JsonObject.class);
//                                // 发送消息内容到客户端
//                                if (jsonResponse.has("message") && jsonResponse.get("message").isJsonObject()) {
//                                    String content = jsonResponse.getAsJsonObject("message").get("content").getAsString();
//                                    emitter.send(SseEmitter.event()
//                                            .data(content)
//                                            .name("message")
//                                    );
//                                }
//
//                                // 处理函数调用
//                                if (jsonResponse.has("function_call") && jsonResponse.get("function_call").isJsonObject()) {
//                                    emitter.send(SseEmitter.event()
//                                            .data(jsonResponse.get("function_call"))
//                                            .name("function_call")
//                                    );
//                                }
//
//                            } catch (Exception e) {
//                                emitter.completeWithError(e);
//                                future.completeExceptionally(e);
//                                break;
//                            }
//                        }
//                    }
                try (ResponseBody responseBody = response.body()) {
                    if (!response.isSuccessful()) {
                        emitter.completeWithError(new IOException("API错误"));
                        return;
                    }
                    StringBuilder fullAnswer = new StringBuilder();
                    BufferedSource source = responseBody.source();
                    String line;
                    while ((line = source.readUtf8Line()) != null) {
                        if (line.startsWith("data: ")) {
                            String data = line.substring(6).trim();
                            if (data.equals("[DONE]")) {
                                // 发送完整回答并结束流
                                emitter.send(fullAnswer.toString());
                                emitter.complete();
                                break;
                            }
                            try {
                                JsonObject json = gson.fromJson(data, JsonObject.class);
                                System.out.println(json);
                                String answerFragment = json.get("answer").getAsString();
                                // 过滤无效字符(如️⃣、\n)
                                String cleanedFragment = answerFragment.replaceAll("[^\u4e00-\u9fa50-9a-zA-Z\\s\\p{Punct}]", "");
                                fullAnswer.append(cleanedFragment); // 拼接碎片
                                // 可选:每拼接一部分发送给前端(需前端支持增量渲染)
                                // emitter.send(SseEmitter.event().data(cleanedFragment));
                            } catch (Exception e) {
                                emitter.sendAndComplete("解析错误: " + e.getMessage());
                            }
                        }
                    }
                } catch (Exception e) {
                    emitter.completeWithError(e);
                    future.completeExceptionally(e);
                }
            }
        });