1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package tech.aiflowy.ai.node;
import com.agentsflex.core.chain.Chain;
import com.agentsflex.core.chain.Parameter;
import com.agentsflex.core.chain.node.BaseNode;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.jfinal.template.stat.ast.For;
import com.mybatisflex.core.row.Db;
import com.mybatisflex.core.row.Row;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.Select;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;
import tech.aiflowy.common.util.Maps;
import tech.aiflowy.common.web.exceptions.BusinessException;
 
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
 
 
/**
 * SQL查询节点
 *
 * @author tao
 * @date 2025-05-21
 */
public class SqlNode extends BaseNode {
 
    private final String sql;
 
    private static final Logger logger = LoggerFactory.getLogger(SqlNode.class);
 
    public SqlNode(String sql) {
        this.sql = sql;
    }
 
    @Override
    protected Map<String, Object> execute(Chain chain) {
 
        Map<String, Object> map = chain.getParameterValues(this);
        Map<String, Object> res = new HashMap<>();
 
 
        Map<String, Object> formatSqlMap = formatSql(sql,map);
        String formatSql = (String)formatSqlMap.get("replacedSql");
 
        Statement statement = null;
        try {
            statement  = CCJSqlParserUtil.parse(formatSql);
 
        } catch (JSQLParserException e) {
            logger.error("sql 解析报错:",e);
            throw new BusinessException("SQL解析失败,请确认SQL语法无误");
        }
 
        if (!(statement instanceof Select)) {
            logger.error("sql 解析报错:statement instanceof Select 结果为false");
            throw new BusinessException("仅支持查询语句!");
        }
 
        List<String> paramNames = (List<String>) formatSqlMap.get("paramNames");
 
        List<Object> paramValues = new ArrayList<>();
        paramNames.forEach(paramName -> {
            Object o = map.get(paramName);
            paramValues.add(o);
        });
 
        List<Row> rows = Db.selectListBySql(formatSql, paramValues.toArray());
 
        if (rows == null || rows.isEmpty()) {
            return Collections.emptyMap();
        }
 
        res.put("queryData",rows);
        return res;
    }
 
    private Map<String, Object> formatSql(String rawSql, Map<String, Object> paramMap) {
 
        if (!StringUtils.hasLength(rawSql)) {
            logger.error("sql解析报错:sql为空");
            throw new BusinessException("sql 不能为空!");
        }
 
        // 匹配 {{?...}} 表示可用占位符的参数
        Pattern paramPattern = Pattern.compile("\\{\\{\\?([^}]+)}}");
 
        // 匹配 {{...}} 表示直接替换的参数(非占位符)
        Pattern directPattern = Pattern.compile("\\{\\{([^}?][^}]*)}}");
 
        List<String> paramNames = new ArrayList<>();
        StringBuffer sqlBuffer = new StringBuffer();
 
        // 替换 {{?...}}  ->  ?
        Matcher paramMatcher = paramPattern.matcher(rawSql);
        while (paramMatcher.find()) {
            String paramName = paramMatcher.group(1).trim();
            paramNames.add(paramName);
            paramMatcher.appendReplacement(sqlBuffer, "?");
        }
        paramMatcher.appendTail(sqlBuffer);
        String intermediateSql = sqlBuffer.toString();
 
        // 替换 {{...}}  -> 实际值(用于表名/列名等)
        sqlBuffer = new StringBuffer(); // 清空 buffer 重新处理
        Matcher directMatcher = directPattern.matcher(intermediateSql);
        while (directMatcher.find()) {
            String key = directMatcher.group(1).trim();
            Object value = paramMap.get(key);
            if (value == null) {
                logger.error("未找到参数:" + key);
                throw new BusinessException("sql解析失败,请确保sql语法正确!");
            }
 
            String safeValue = value.toString();
 
            directMatcher.appendReplacement(sqlBuffer, Matcher.quoteReplacement(safeValue));
        }
        directMatcher.appendTail(sqlBuffer);
 
        String finalSql = sqlBuffer.toString().trim();
 
        // 清理末尾分号与中文引号
        if (finalSql.endsWith(";") || finalSql.endsWith(";")) {
            finalSql = finalSql.substring(0, finalSql.length() - 1);
        }
        finalSql = finalSql.replace("“", "\"").replace("”", "\"");
 
        logger.info("Final SQL: {}", finalSql);
        logger.info("Param names: {}", paramNames);
 
        Map<String, Object> result = new HashMap<>();
        result.put("replacedSql", finalSql);
        result.put("paramNames", paramNames);
        return result;
    }
 
 
 
    @Override
    public String toString() {
        return "SqlNode{" +
                "sql='" + sql + '\'' +
                ", outputDefs=" + outputDefs +
                ", parameters=" + parameters +
                ", id='" + id + '\'' +
                ", name='" + name + '\'' +
                ", description='" + description + '\'' +
                ", async=" + async +
                ", inwardEdges=" + inwardEdges +
                ", outwardEdges=" + outwardEdges +
                ", condition=" + condition +
                ", memory=" + memory +
                ", nodeStatus=" + nodeStatus +
                '}';
    }
}