#IC1P55 #C17PK
1. 修改 function_call 模型调用方式,改为 stream=false
2. 修复 AiWorkflowFunction 中 invoke 方法未对 Tinyflow 设置 provider 的问题
2个文件已修改
194 ■■■■■ 已修改文件
aiflowy-modules/aiflowy-module-ai/src/main/java/tech/aiflowy/ai/controller/AiBotController.java 145 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
aiflowy-modules/aiflowy-module-ai/src/main/java/tech/aiflowy/ai/entity/AiWorkflowFunction.java 49 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
aiflowy-modules/aiflowy-module-ai/src/main/java/tech/aiflowy/ai/controller/AiBotController.java
@@ -151,81 +151,100 @@
        final Boolean[] needClose = {true};
        ServletRequestAttributes sra = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (!humanMessage.getFunctions().isEmpty()) {
            try {
                AiMessageResponse aiMessageResponse = llm.chat(historiesPrompt);
                function_call(aiMessageResponse, emitter, needClose, historiesPrompt, llm, prompt);
            } catch (Exception e) {
                emitter.completeWithError(e);
            }
        llm.chatStream(historiesPrompt, new StreamResponseListener() {
            @Override
            public void onMessage(ChatContext context, AiMessageResponse response) {
                try {
                    RequestContextHolder.setRequestAttributes(sra, true);
                    String content = response.getMessage().getContent();
                    Object messageContent = response.getMessage();
                    if (StringUtil.hasText(content)) {
                        String jsonResult = JSON.toJSONString(messageContent);
                        emitter.send(jsonResult);
            if (needClose[0]) {
                System.out.println("function chat complete");
                emitter.complete();
            }
        } else {
            llm.chatStream(historiesPrompt, new StreamResponseListener() {
                @Override
                public void onMessage(ChatContext context, AiMessageResponse response) {
                    try {
                        function_call(response, emitter, needClose, historiesPrompt, llm, prompt);
                    } catch (Exception e) {
                        emitter.completeWithError(e);
                    }
                    List<FunctionCaller> functionCallers = response.getFunctionCallers();
                    if (CollectionUtil.hasItems(functionCallers)) {
                        needClose[0] = false;
                        for (FunctionCaller functionCaller : functionCallers) {
                            Object result = functionCaller.call();
                            if (ObjectUtil.isNotEmpty(result)) {
                }
                                String newPrompt = "请根据以下内容回答用户,内容是:\n" + result + "\n 用户的问题是:" + prompt;
                                historiesPrompt.addMessageTemporary(new HumanMessage(newPrompt));
                                llm.chatStream(historiesPrompt, new StreamResponseListener() {
                                    @Override
                                    public void onMessage(ChatContext context, AiMessageResponse response) {
                                        needClose[0] = true;
                                        String content = response.getMessage().getContent();
                                        Object messageContent = response.getMessage();
                                        if (StringUtil.hasText(content)) {
                                            String jsonResult = JSON.toJSONString(messageContent);
                                            emitter.send(jsonResult);
                                        }
                                    }
                                    @Override
                                    public void onStop(ChatContext context) {
                                        if (needClose[0]) {
                                            System.out.println("function chat complete");
                                            emitter.complete();
                                        }
                                        historiesPrompt.clearTemporaryMessages();
                                    }
                                    @Override
                                    public void onFailure(ChatContext context, Throwable throwable) {
                                        emitter.completeWithError(throwable);
                                    }
                                });
                            }
                        }
                @Override
                public void onStop(ChatContext context) {
                    if (needClose[0]) {
                        System.out.println("normal chat complete");
                        emitter.complete();
                    }
                } catch (Exception e) {
                    emitter.completeWithError(e);
                }
            }
            @Override
            public void onStop(ChatContext context) {
                if (needClose[0]) {
                    System.out.println("normal chat complete");
                    emitter.complete();
                @Override
                public void onFailure(ChatContext context, Throwable throwable) {
                    emitter.completeWithError(throwable);
                }
            }
            @Override
            public void onFailure(ChatContext context, Throwable throwable) {
                emitter.completeWithError(throwable);
            }
        });
            });
        }
        return emitter;
    }
    private void function_call(AiMessageResponse aiMessageResponse, MySseEmitter emitter, Boolean[] needClose, HistoriesPrompt historiesPrompt, Llm llm, String prompt) {
        ServletRequestAttributes sra = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        RequestContextHolder.setRequestAttributes(sra, true);
        String content = aiMessageResponse.getMessage().getContent();
        Object messageContent = aiMessageResponse.getMessage();
        if (StringUtil.hasText(content)) {
            String jsonResult = JSON.toJSONString(messageContent);
            emitter.send(jsonResult);
        }
        List<FunctionCaller> functionCallers = aiMessageResponse.getFunctionCallers();
        if (CollectionUtil.hasItems(functionCallers)) {
            needClose[0] = false;
            for (FunctionCaller functionCaller : functionCallers) {
                Object result = functionCaller.call();
                if (ObjectUtil.isNotEmpty(result)) {
                    String newPrompt = "请根据以下内容回答用户,内容是:\n" + result + "\n 用户的问题是:" + prompt;
                    historiesPrompt.addMessageTemporary(new HumanMessage(newPrompt));
                    llm.chatStream(historiesPrompt, new StreamResponseListener() {
                        @Override
                        public void onMessage(ChatContext context, AiMessageResponse response) {
                            needClose[0] = true;
                            String content = response.getMessage().getContent();
                            Object messageContent = response.getMessage();
                            if (StringUtil.hasText(content)) {
                                String jsonResult = JSON.toJSONString(messageContent);
                                emitter.send(jsonResult);
                            }
                        }
                        @Override
                        public void onStop(ChatContext context) {
                            if (needClose[0]) {
                                System.out.println("function chat complete");
                                emitter.complete();
                            }
                            historiesPrompt.clearTemporaryMessages();
                        }
                        @Override
                        public void onFailure(ChatContext context, Throwable throwable) {
                            emitter.completeWithError(throwable);
                        }
                    });
                }
            }
        }
    }
    private void appendWorkflowFunctions(BigInteger botId, HumanMessage humanMessage) {
        QueryWrapper queryWrapper = QueryWrapper.create().eq(AiBotWorkflow::getBotId, botId);
        List<AiBotWorkflow> aiBotWorkflows = aiBotWorkflowService.getMapper().selectListWithRelationsByQuery(queryWrapper);
aiflowy-modules/aiflowy-module-ai/src/main/java/tech/aiflowy/ai/entity/AiWorkflowFunction.java
@@ -1,5 +1,7 @@
package tech.aiflowy.ai.entity;
import tech.aiflowy.ai.service.AiKnowledgeService;
import tech.aiflowy.ai.service.AiLlmService;
import tech.aiflowy.ai.service.AiWorkflowService;
import tech.aiflowy.common.util.SpringContextUtil;
import com.agentsflex.core.chain.Chain;
@@ -58,13 +60,58 @@
        AiWorkflowService service = SpringContextUtil.getBean(AiWorkflowService.class);
        AiWorkflow workflow = service.getById(this.workflowId);
        if (workflow != null) {
            Chain chain = workflow.toTinyflow().toChain();
            Tinyflow tinyflow = workflow.toTinyflow();
            setLlmProvider(tinyflow);
            setKnowledgeProvider(tinyflow);
            Chain chain = tinyflow.toChain();
            return chain.executeForResult(argsMap);
        } else {
            throw new RuntimeException("can not find the workflow by id: " + this.workflowId);
        }
    }
    private void setLlmProvider( Tinyflow tinyflow){
        AiLlmService aiLlmService = SpringContextUtil.getBean(AiLlmService.class);
        tinyflow.setLlmProvider(new LlmProvider() {
            @Override
            public Llm getLlm(Object id) {
                AiLlm aiLlm = aiLlmService.getById(new BigInteger(id.toString()));
                return aiLlm.toLlm();
            }
        });
    }
    private void setKnowledgeProvider( Tinyflow tinyflow){
        AiLlmService aiLlmService = SpringContextUtil.getBean(AiLlmService.class);
        AiKnowledgeService aiKnowledgeService= SpringContextUtil.getBean(AiKnowledgeService.class);
        tinyflow.setKnowledgeProvider(new KnowledgeProvider() {
            @Override
            public Knowledge getKnowledge(Object o) {
                AiKnowledge aiKnowledge = aiKnowledgeService.getById(new BigInteger(o.toString()));
                return  new Knowledge() {
                    @Override
                    public List<Document> search(String keyword, int limit, KnowledgeNode knowledgeNode, Chain chain) {
                        DocumentStore documentStore = aiKnowledge.toDocumentStore();
                        if (documentStore == null){
                            return null;
                        }
                        AiLlm aiLlm = aiLlmService.getById(aiKnowledge.getVectorEmbedLlmId());
                        if (aiLlm == null){
                            return null;
                        }
                        documentStore.setEmbeddingModel(aiLlm.toLlm());
                        SearchWrapper wrapper = new SearchWrapper();
                        wrapper.setMaxResults(Integer.valueOf(limit));
                        wrapper.setText(keyword);
                        StoreOptions options = StoreOptions.ofCollectionName(aiKnowledge.getVectorStoreCollection());
                        List<Document> results = documentStore.search(wrapper, options);
                        return results;
                    }
                };
            }
        });
    }
    @Override
    public String toString() {
        return "AiWorkflowFunction{" +