package tech.aiflowy.ai.config;
|
|
import com.agentsflex.core.message.AiMessage;
|
import com.agentsflex.core.prompt.HistoriesPrompt;
|
import com.alibaba.fastjson.JSON;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.google.gson.*;
|
import com.mybatisflex.core.query.QueryWrapper;
|
import okhttp3.*;
|
import okio.BufferedSource;
|
import org.springframework.web.multipart.MultipartFile;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import tech.aiflowy.ai.entity.AiBotMessage;
|
import tech.aiflowy.ai.service.AiBotMessageService;
|
import tech.aiflowy.common.ai.MySseEmitter;
|
|
import java.io.File;
|
import java.io.IOException;
|
import java.math.BigInteger;
|
import java.util.*;
|
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.TimeUnit;
|
|
public class DifyStreamClient {
|
private final OkHttpClient client;
|
private final String apiUrl;
|
private final String apiKey;
|
private final Gson gson;
|
private String prompt;
|
private AiBotMessageService aiBotMessageService;
|
boolean blean = false;
|
|
public DifyStreamClient(String apiUrl, String apiKey, AiBotMessageService aiBotMessageService) {
|
this.apiUrl = apiUrl;
|
this.apiKey = apiKey;
|
this.gson = new GsonBuilder().setPrettyPrinting().create();
|
this.client = new OkHttpClient.Builder().connectTimeout(60, TimeUnit.SECONDS) // 连接超时
|
.readTimeout(120, TimeUnit.SECONDS).build();
|
this.aiBotMessageService = aiBotMessageService;
|
}
|
|
// 在DifyStreamClient类中添加以下方法
|
public CompletableFuture<Void> runWorkflow(Map<String, Object> inputs, String message, String userId, MySseEmitter emitter, String sessionId, BigInteger botId) {
|
// 构建请求JSON
|
JsonObject requestBody = new JsonObject();
|
// 添加inputs参数
|
JsonObject inputsJson = new JsonObject();
|
if (inputs != null) {
|
for (Map.Entry<String, Object> entry : inputs.entrySet()) {
|
if (entry.getValue() instanceof String) {
|
inputsJson.addProperty(entry.getKey(), (String) entry.getValue());
|
} else if (entry.getValue() instanceof Number) {
|
inputsJson.addProperty(entry.getKey(), (Number) entry.getValue());
|
} else if (entry.getValue() instanceof Boolean) {
|
inputsJson.addProperty(entry.getKey(), (Boolean) entry.getValue());
|
} else {
|
// 对于复杂对象,转换为JSON字符串
|
inputsJson.add(entry.getKey(), gson.toJsonTree(entry.getValue()));
|
}
|
}
|
}
|
requestBody.add("inputs", inputsJson);
|
// 设置响应模式和用户ID
|
requestBody.addProperty("response_mode", "streaming");
|
requestBody.addProperty("user", userId);
|
|
// System.out.println(requestBody+"==============================================================================================");
|
// 创建请求
|
RequestBody body = RequestBody.create(
|
gson.toJson(requestBody),
|
MediaType.parse("application/json; charset=utf-8")
|
);
|
|
Request request = new Request.Builder()
|
.url(apiUrl)
|
.post(body)
|
.header("Authorization", apiKey)
|
.header("Content-Type", "application/json")
|
.build();
|
|
CompletableFuture<Void> future = new CompletableFuture<>();
|
|
// 设置SseEmitter生命周期回调
|
emitter.onTimeout(() -> {
|
System.out.println("SSE连接超时");
|
emitter.complete();
|
future.complete(null);
|
});
|
|
emitter.onCompletion(() -> {
|
System.out.println("SSE连接已完成");
|
future.complete(null);
|
});
|
|
emitter.onError(e -> {
|
System.out.println("SSE连接错误: " + e.getMessage());
|
emitter.completeWithError(e);
|
future.completeExceptionally(e);
|
});
|
|
// 发送异步请求
|
client.newCall(request).enqueue(new Callback() {
|
@Override
|
public void onFailure(Call call, IOException e) {
|
emitter.completeWithError(e);
|
future.completeExceptionally(e);
|
}
|
|
@Override
|
public void onResponse(Call call, Response response) {
|
StringBuffer sb = new StringBuffer();
|
try (ResponseBody responseBody = response.body()) {
|
if (!response.isSuccessful()) {
|
emitter.completeWithError(new IOException("API错误: " + response.code()));
|
return;
|
}
|
|
// 使用BufferedSource逐行读取响应内容
|
BufferedSource source = responseBody.source();
|
String line;
|
|
while ((line = source.readUtf8Line()) != null) {
|
if (line.startsWith("data: ")) {
|
String data = line.substring(6).trim();
|
|
// 忽略空数据或结束标记
|
if (data.isEmpty() || data.equals("[DONE]")) {
|
continue;
|
}
|
|
try {
|
// 这里需要根据实际API返回结构调整
|
// 假设API返回的格式是{ "output": "消息内容" }
|
JsonObject jsonObject = gson.fromJson(data, JsonObject.class);
|
String title = null;
|
// System.out.println(jsonObject);
|
if (jsonObject != null && jsonObject.has("data")) {
|
JsonElement dataElement = jsonObject.get("data");
|
if (dataElement != null && !dataElement.isJsonNull()) {
|
JsonObject dataObject = dataElement.getAsJsonObject();
|
if (dataObject != null && dataObject.has("node_type")) {
|
continue;
|
}
|
if (dataObject != null && dataObject.has("title")) {
|
JsonElement titleElement = dataObject.get("title");
|
if (titleElement != null && !titleElement.isJsonNull()) {
|
title = titleElement.getAsString();
|
}
|
}
|
else if (dataObject != null && dataObject.has("text")) {
|
JsonElement titleElement = dataObject.get("text");
|
if (titleElement != null && !titleElement.isJsonNull()) {
|
title = titleElement.getAsString();
|
}
|
}else if (dataObject != null && dataObject.has("outputs")) {
|
JsonElement titleElement = dataObject.get("outputs");
|
if (titleElement != null && !titleElement.isJsonNull()) {
|
|
AiBotMessage aiBotMessage = new AiBotMessage();
|
aiBotMessage.setBotId(botId);
|
aiBotMessage.setSessionId(sessionId);
|
aiBotMessage.setAccountId(new BigInteger(userId));
|
aiBotMessage.setRole("assistant");
|
aiBotMessage.setContent(sb.toString());
|
aiBotMessage.setCreated(new Date());
|
aiBotMessage.setIsExternalMsg(1);
|
aiBotMessageService.save(aiBotMessage);
|
// dataObject = titleElement.getAsJsonObject();
|
// if (dataObject != null && dataObject.has("text")) {
|
// titleElement = dataObject.get("text");
|
// if (titleElement != null && !titleElement.isJsonNull()) {
|
// title = titleElement.getAsString();
|
// }
|
// }else if (dataObject != null && dataObject.has("data")) {
|
// titleElement = dataObject.get("data");
|
// if (titleElement != null && !titleElement.isJsonNull()) {
|
// title = titleElement.getAsString();
|
// }
|
// }
|
}
|
}
|
}
|
}
|
|
// 创建消息对象并发送给前端
|
AiMessage aiMessage = new AiMessage();
|
aiMessage.setContent(title);
|
System.out.println(gson.fromJson(data, JsonObject.class));
|
sb.append(aiMessage.getContent());
|
// 将消息发送给前端
|
if (aiMessage.getContent() != null){
|
emitter.send(JSON.toJSONString(aiMessage));
|
}
|
|
} catch (Exception e) {
|
// 记录解析错误但继续处理后续数据
|
System.err.println("解析响应数据时出错: " + e.getMessage());
|
emitter.completeWithError(e);
|
}
|
}
|
}
|
|
// 所有数据处理完毕,发送完成信号
|
emitter.send(SseEmitter.event().name("complete"));
|
emitter.complete();
|
|
} catch (IOException e) {
|
emitter.completeWithError(e);
|
emitter.complete();
|
} catch (Exception e) {
|
// 处理其他异常
|
emitter.completeWithError(e);
|
}
|
}
|
});
|
|
return future;
|
}
|
|
public String fileUpload(String userId, MultipartFile file) {
|
// 用户标识,要和发送消息接口的 user 保持一致
|
String user = userId;
|
|
// 判断文件是否为空
|
if (file.isEmpty()) {
|
System.out.println("上传文件为空");
|
return "上传文件为空";
|
}
|
|
OkHttpClient client = new OkHttpClient();
|
|
// 构建 multipart/form-data 请求体
|
MultipartBody.Builder requestBodyBuilder = null;
|
try {
|
requestBodyBuilder = new MultipartBody.Builder()
|
.setType(MultipartBody.FORM)
|
// 添加上传文件,这里直接用 MultipartFile 的字节数组构建请求体
|
.addFormDataPart("file", file.getOriginalFilename(),
|
RequestBody.create(MediaType.parse("application/octet-stream"), file.getBytes()))
|
.addFormDataPart("user", user);
|
} catch (IOException e) {
|
throw new RuntimeException(e);
|
}
|
|
RequestBody requestBody = requestBodyBuilder.build();
|
|
Request request = new Request.Builder()
|
.url(apiUrl) // 替换为实际的接口地址常量
|
.post(requestBody)
|
.header("Authorization", apiKey) // 替换为实际的授权密钥常量
|
.build();
|
|
try (Response response = client.newCall(request).execute()) {
|
if (!response.isSuccessful()) {
|
System.out.println("请求失败,状态码:" + response.code());
|
return user;
|
}
|
String responseBody = response.body().string();
|
System.out.println("上传结果:" + responseBody);
|
return responseBody;
|
} catch (IOException e) {
|
e.printStackTrace();
|
return "文件上传失败:" + e.getMessage();
|
}
|
}
|
|
// 流式聊天方法 - 直接集成SseEmitter
|
public CompletableFuture<Void> chatStream(String message, String userId, MySseEmitter emitter, String sessionId, BigInteger botId) {
|
prompt = message;
|
QueryWrapper qw = new QueryWrapper();
|
qw.eq(AiBotMessage::getSessionId, sessionId)
|
.orderBy(AiBotMessage::getId,false)
|
.limit(6);
|
|
List<AiBotMessage> history = aiBotMessageService.list(qw);
|
// System.out.println("======================history==================\n");
|
// for (AiBotMessage aiBotMessage : history) {
|
// System.out.println(aiBotMessage.toMessage());
|
// }
|
// System.out.println("\n======================history==================");
|
|
// 构建请求JSON
|
JsonObject requestBody = new JsonObject();
|
requestBody.add("inputs", new JsonObject());
|
requestBody.addProperty("query", message);
|
requestBody.addProperty("response_mode", "streaming");
|
requestBody.addProperty("conversation_id", "");
|
requestBody.addProperty("user", userId);
|
// requestBody.add("files", new JsonArray());
|
// 添加历史对话信息
|
JsonArray historyArray = new JsonArray();
|
// for (AiBotMessage msg : history) {
|
// historyArray.add(String.valueOf(msg));
|
// }
|
// requestBody.add("history", historyArray);
|
|
RequestBody body = RequestBody.create(
|
gson.toJson(requestBody),
|
MediaType.parse("application/json; charset=utf-8")
|
);
|
|
Request request = new Request.Builder()
|
.url(apiUrl)
|
.post(body)
|
.header("Authorization", apiKey)
|
.header("Content-Type", "application/json")
|
.build();
|
|
CompletableFuture<Void> future = new CompletableFuture<>();
|
|
// 设置SseEmitter生命周期回调
|
emitter.onTimeout(() -> {
|
System.out.println("SSE连接超时");
|
emitter.complete();
|
future.complete(null);
|
});
|
|
emitter.onCompletion(() -> {
|
System.out.println("SSE连接已完成");
|
future.complete(null);
|
});
|
|
emitter.onError(e -> {
|
System.out.println("SSE连接错误: " + e.getMessage());
|
emitter.completeWithError(e);
|
future.completeExceptionally(e);
|
});
|
// 发送异步请求
|
client.newCall(request).enqueue(new Callback() {
|
@Override
|
public void onFailure(Call call, IOException e) {
|
emitter.completeWithError(e);
|
future.completeExceptionally(e);
|
}
|
|
@Override
|
public void onResponse(Call call, Response response) {
|
AiMessage aiMessage = new AiMessage();
|
com.agentsflex.core.llm.response.AiMessageResponse aiMessageResponse
|
= new com.agentsflex.core.llm.response.AiMessageResponse(new HistoriesPrompt(), response.message(), aiMessage);
|
try (ResponseBody responseBody = response.body()) {
|
if (!response.isSuccessful()) {
|
emitter.completeWithError(new IOException("API错误: " + response.code()));
|
return;
|
}
|
|
// 使用BufferedSource逐行读取响应内容
|
BufferedSource source = responseBody.source();
|
String line;
|
StringBuffer sb = new StringBuffer();
|
// 标记是否为第一条有效数据(用于处理某些API的特殊格式)
|
boolean isFirstData = true;
|
|
while ((line = source.readUtf8Line()) != null) {
|
if (line.startsWith("data: ")) {
|
String data = line.substring(6).trim();
|
|
// 忽略空数据或结束标记
|
if (data.isEmpty() || data.equals("[DONE]")) {
|
continue;
|
}
|
|
try {
|
// 根据实际API响应结构提取消息内容
|
// 这里需要根据具体API调整路径
|
JsonObject messageObj = gson.fromJson(data, JsonObject.class);
|
// System.out.println(messageObj);
|
if(messageObj.get("event").getAsString().equals("message_end")){
|
AiBotMessage aiBotMessage = new AiBotMessage();
|
aiBotMessage.setBotId(botId);
|
aiBotMessage.setSessionId(sessionId);
|
aiBotMessage.setAccountId(new BigInteger(userId));
|
aiBotMessage.setRole("assistant");
|
aiBotMessage.setContent(sb.toString());
|
aiBotMessage.setCreated(new Date());
|
aiBotMessage.setIsExternalMsg(1);
|
aiBotMessageService.save(aiBotMessage);
|
// System.out.println("end");
|
}else{
|
String context = messageObj.get("answer").getAsString();
|
// System.out.println(context);
|
if (context != null) {
|
// // 只移除HTML标签,保留Markdown特殊字符
|
// context = context.replaceAll("(?i)<[^>]*>", "");
|
context = context.replaceFirst("Thinking...", "");
|
if(context.startsWith("<details")){
|
context = context.replaceAll("(?i)<[^>]*>", "");
|
}
|
}
|
aiMessage.setContent(context);
|
aiMessageResponse.setMessage(aiMessage);
|
if (!messageObj.get("answer").getAsString().isEmpty()) {
|
if(!blean && aiMessage.getContent().startsWith("</details>")){
|
blean = true;
|
aiMessage.setContent(aiMessage.getContent().replaceAll("(?i)<[^>]*>", "\n\n"));
|
}
|
if(blean){
|
sb.append(aiMessage.getContent());
|
// System.out.println(aiMessage);
|
// 发送消息片段给前端
|
emitter.send(JSON.toJSONString(aiMessage));
|
}
|
}
|
}
|
|
// 重置第一条数据标记
|
isFirstData = false;
|
|
} catch (Exception e) {
|
// 记录解析错误但继续处理后续数据
|
emitter.completeWithError(e);
|
}
|
}
|
}
|
|
// 所有数据处理完毕,发送完成信号
|
emitter.send(SseEmitter.event().name("complete"));
|
emitter.complete();
|
|
} catch (IOException e) {
|
emitter.completeWithError(e);
|
emitter.complete();
|
} catch (Exception e) {
|
// 处理其他异常
|
emitter.completeWithError(e);
|
}
|
}
|
});
|
|
return future;
|
}
|
|
// 数据模型类
|
public interface StreamResponseListener {
|
void onMessage(ChatContext context, AiMessageResponse response);
|
void onStop(ChatContext context);
|
void onFailure(ChatContext context, Throwable throwable);
|
}
|
|
public static class ChatContext {
|
// 上下文信息
|
}
|
|
public static class AiMessageResponse {
|
private Message message;
|
private JsonArray functionCallers;
|
|
public Message getMessage() { return message; }
|
public void setMessage(Message message) { this.message = message; }
|
public JsonArray getFunctionCallers() { return functionCallers; }
|
public void setFunctionCallers(JsonArray functionCallers) { this.functionCallers = functionCallers; }
|
}
|
|
public static class Message {
|
private String content;
|
|
public String getContent() { return content; }
|
public void setContent(String content) { this.content = content; }
|
}
|
}
|