package tech.aiflowy.ai.config;
|
|
import com.agentsflex.core.llm.response.AiMessageResponse;
import com.agentsflex.core.message.AiMessage;
import com.agentsflex.core.message.Message;
import com.agentsflex.core.prompt.HistoriesPrompt;
import com.agentsflex.core.prompt.Prompt;
import com.alibaba.fastjson.JSON;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.mybatisflex.core.query.QueryWrapper;
import okhttp3.*;
import okio.BufferedSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import tech.aiflowy.ai.entity.AiBotConversationMessage;
import tech.aiflowy.ai.entity.AiBotMessage;
import tech.aiflowy.ai.service.AiBotConversationMessageService;
import tech.aiflowy.ai.service.AiBotMessageService;
import tech.aiflowy.ai.service.impl.AiBotMessageServiceImpl;
import tech.aiflowy.common.ai.MySseEmitter;
import tech.aiflowy.common.util.StringUtil;
import tech.aiflowy.common.web.controller.BaseCurdController;

import javax.annotation.Resource;

import java.io.IOException;
import java.math.BigInteger;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.regex.Pattern;
|
public class DifyStreamClient {
|
private final OkHttpClient client;
|
private final String apiUrl;
|
private final String apiKey;
|
private final Gson gson;
|
private String prompt;
|
private AiBotMessageService aiBotMessageService;
|
|
public DifyStreamClient(String apiUrl, String apiKey, AiBotMessageService aiBotMessageService) {
|
this.apiUrl = apiUrl;
|
this.apiKey = apiKey;
|
this.gson = new GsonBuilder().setPrettyPrinting().create();
|
this.client = new OkHttpClient.Builder().build();
|
this.aiBotMessageService = aiBotMessageService;
|
}
|
|
// 流式聊天方法 - 直接集成SseEmitter
|
public CompletableFuture<Void> chatStream(String message, String userId, MySseEmitter emitter, String sessionId, BigInteger botId) {
|
prompt = message;
|
QueryWrapper qw = new QueryWrapper();
|
qw.eq(AiBotMessage::getSessionId, sessionId)
|
.orderBy(AiBotMessage::getId,false)
|
.limit(6);
|
|
List<AiBotMessage> history = aiBotMessageService.list(qw);
|
// System.out.println("======================history==================\n");
|
// for (AiBotMessage aiBotMessage : history) {
|
// System.out.println(aiBotMessage.toMessage());
|
// }
|
// System.out.println("\n======================history==================");
|
|
// 构建请求JSON
|
JsonObject requestBody = new JsonObject();
|
requestBody.add("inputs", new JsonObject());
|
requestBody.addProperty("query", message);
|
requestBody.addProperty("response_mode", "streaming");
|
requestBody.addProperty("conversation_id", "");
|
requestBody.addProperty("user", userId);
|
requestBody.add("files", new JsonArray());
|
// 添加历史对话信息
|
JsonArray historyArray = new JsonArray();
|
for (AiBotMessage msg : history) {
|
historyArray.add(String.valueOf(msg));
|
}
|
requestBody.add("history", historyArray);
|
|
RequestBody body = RequestBody.create(
|
gson.toJson(requestBody),
|
MediaType.parse("application/json; charset=utf-8")
|
);
|
|
Request request = new Request.Builder()
|
.url(apiUrl)
|
.post(body)
|
.header("Authorization", apiKey)
|
.header("Content-Type", "application/json")
|
.build();
|
|
CompletableFuture<Void> future = new CompletableFuture<>();
|
|
// 设置SseEmitter生命周期回调
|
emitter.onTimeout(() -> {
|
System.out.println("SSE连接超时");
|
emitter.complete();
|
future.complete(null);
|
});
|
|
emitter.onCompletion(() -> {
|
System.out.println("SSE连接已完成");
|
future.complete(null);
|
});
|
|
emitter.onError(e -> {
|
System.out.println("SSE连接错误: " + e.getMessage());
|
emitter.completeWithError(e);
|
future.completeExceptionally(e);
|
});
|
|
// 发送异步请求
|
client.newCall(request).enqueue(new Callback() {
|
@Override
|
public void onFailure(Call call, IOException e) {
|
emitter.completeWithError(e);
|
future.completeExceptionally(e);
|
}
|
|
@Override
|
public void onResponse(Call call, Response response) {
|
AiMessage aiMessage = new AiMessage();
|
com.agentsflex.core.llm.response.AiMessageResponse aiMessageResponse
|
= new com.agentsflex.core.llm.response.AiMessageResponse(new HistoriesPrompt(), response.message(), aiMessage);
|
try (ResponseBody responseBody = response.body()) {
|
if (!response.isSuccessful()) {
|
emitter.completeWithError(new IOException("API错误: " + response.code()));
|
return;
|
}
|
|
// 使用BufferedSource逐行读取响应内容
|
BufferedSource source = responseBody.source();
|
String line;
|
StringBuffer sb = new StringBuffer();
|
// 标记是否为第一条有效数据(用于处理某些API的特殊格式)
|
boolean isFirstData = true;
|
|
while ((line = source.readUtf8Line()) != null) {
|
if (line.startsWith("data: ")) {
|
String data = line.substring(6).trim();
|
|
// 忽略空数据或结束标记
|
if (data.isEmpty() || data.equals("[DONE]")) {
|
continue;
|
}
|
|
try {
|
// 根据实际API响应结构提取消息内容
|
// 这里需要根据具体API调整路径
|
JsonObject messageObj = gson.fromJson(data, JsonObject.class);
|
// System.out.println(messageObj);
|
if(messageObj.get("event").getAsString().equals("message_end")){
|
AiBotMessage aiBotMessage = new AiBotMessage();
|
aiBotMessage.setBotId(botId);
|
aiBotMessage.setSessionId(sessionId);
|
aiBotMessage.setAccountId(new BigInteger(userId));
|
aiBotMessage.setRole("assistant");
|
aiBotMessage.setContent(sb.toString());
|
aiBotMessage.setCreated(new Date());
|
aiBotMessage.setIsExternalMsg(1);
|
aiBotMessageService.save(aiBotMessage);
|
// System.out.println("end");
|
}else{
|
String context = messageObj.get("answer").getAsString();
|
if (context != null) {
|
// // 只移除HTML标签,保留Markdown特殊字符
|
// context = context.replaceAll("(?i)<[^>]*>", "");
|
context = context.replaceFirst("Thinking...", "");
|
}
|
sb.append(context);
|
aiMessage.setContent(context);
|
aiMessageResponse.setMessage(aiMessage);
|
if (StringUtil.hasText(messageObj.get("answer").getAsString())) {
|
// System.out.println(aiMessage);
|
if(aiMessage.getContent().startsWith("</details>")){
|
aiMessage.setContent(aiMessage.getContent().replaceAll("(?i)<[^>]*>", "</details>"+"\n\n"));
|
}
|
// 发送消息片段给前端
|
emitter.send(JSON.toJSONString(aiMessage));
|
}
|
}
|
|
// 重置第一条数据标记
|
isFirstData = false;
|
|
} catch (Exception e) {
|
// 记录解析错误但继续处理后续数据
|
emitter.completeWithError(e);
|
}
|
}
|
}
|
|
// 所有数据处理完毕,发送完成信号
|
emitter.send(SseEmitter.event().name("complete"));
|
emitter.complete();
|
|
} catch (IOException e) {
|
emitter.completeWithError(e);
|
emitter.complete();
|
} catch (Exception e) {
|
// 处理其他异常
|
emitter.completeWithError(e);
|
}
|
}
|
|
// @Override
|
public void onResponses(Call call, Response response) {
|
// try (ResponseBody responseBody = response.body()) {
|
// if (!response.isSuccessful() || responseBody == null) {
|
// IOException e = new IOException("Unexpected code " + response);
|
// emitter.completeWithError(e);
|
// future.completeExceptionally(e);
|
// return;
|
// }
|
//// String rawResponse = responseBody.string(); // 打印原始响应
|
//// System.out.println("Dify原始响应: " + rawResponse);
|
//
|
// BufferedSource source = responseBody.source();
|
// String line;
|
//
|
// // 处理流式响应
|
// while ((line = source.readUtf8Line()) != null) {
|
// if (line.startsWith("data: ")) {
|
// String data = line.substring(6); // 移除"data: "前缀
|
//
|
// // 跳过空数据
|
// if (data.trim().equals("[DONE]")) {
|
// emitter.complete();
|
// future.complete(null);
|
// break;
|
// }
|
//
|
// try {
|
// // 解析JSON响应
|
// JsonObject jsonResponse = gson.fromJson(data, JsonObject.class);
|
// // 发送消息内容到客户端
|
// if (jsonResponse.has("message") && jsonResponse.get("message").isJsonObject()) {
|
// String content = jsonResponse.getAsJsonObject("message").get("content").getAsString();
|
// emitter.send(SseEmitter.event()
|
// .data(content)
|
// .name("message")
|
// );
|
// }
|
//
|
// // 处理函数调用
|
// if (jsonResponse.has("function_call") && jsonResponse.get("function_call").isJsonObject()) {
|
// emitter.send(SseEmitter.event()
|
// .data(jsonResponse.get("function_call"))
|
// .name("function_call")
|
// );
|
// }
|
//
|
// } catch (Exception e) {
|
// emitter.completeWithError(e);
|
// future.completeExceptionally(e);
|
// break;
|
// }
|
// }
|
// }
|
try (ResponseBody responseBody = response.body()) {
|
if (!response.isSuccessful()) {
|
emitter.completeWithError(new IOException("API错误"));
|
return;
|
}
|
StringBuilder fullAnswer = new StringBuilder();
|
BufferedSource source = responseBody.source();
|
String line;
|
while ((line = source.readUtf8Line()) != null) {
|
if (line.startsWith("data: ")) {
|
String data = line.substring(6).trim();
|
|
if (data.equals("[DONE]")) {
|
// 发送完整回答并结束流
|
emitter.send(fullAnswer.toString());
|
emitter.complete();
|
break;
|
}
|
|
try {
|
JsonObject json = gson.fromJson(data, JsonObject.class);
|
System.out.println(json);
|
String answerFragment = json.get("answer").getAsString();
|
// 过滤无效字符(如️⃣、\n)
|
String cleanedFragment = answerFragment.replaceAll("[^\u4e00-\u9fa50-9a-zA-Z\\s\\p{Punct}]", "");
|
fullAnswer.append(cleanedFragment); // 拼接碎片
|
|
// 可选:每拼接一部分发送给前端(需前端支持增量渲染)
|
// emitter.send(SseEmitter.event().data(cleanedFragment));
|
|
} catch (Exception e) {
|
emitter.sendAndComplete("解析错误: " + e.getMessage());
|
}
|
}
|
}
|
} catch (Exception e) {
|
emitter.completeWithError(e);
|
future.completeExceptionally(e);
|
}
|
}
|
});
|
|
return future;
|
}
|
|
// 数据模型类
|
public interface StreamResponseListener {
|
void onMessage(ChatContext context, AiMessageResponse response);
|
void onStop(ChatContext context);
|
void onFailure(ChatContext context, Throwable throwable);
|
}
|
|
public static class ChatContext {
|
// 上下文信息
|
}
|
|
public static class AiMessageResponse {
|
private Message message;
|
private JsonArray functionCallers;
|
|
public Message getMessage() { return message; }
|
public void setMessage(Message message) { this.message = message; }
|
public JsonArray getFunctionCallers() { return functionCallers; }
|
public void setFunctionCallers(JsonArray functionCallers) { this.functionCallers = functionCallers; }
|
}
|
|
public static class Message {
|
private String content;
|
|
public String getContent() { return content; }
|
public void setContent(String content) { this.content = content; }
|
}
|
}
|