MyBatis-Plus-Join：同一张表 Join 多次

使用的是 MyBatis-Plus-Join，依赖如下：

<dependency>
<groupId>com.github.yulichang</groupId>
<artifactId>mybatis-plus-join</artifactId>
<version>1.2.4</version>
</dependency>

遇到的问题

大致场景是：一张轨迹信息表需要关联多个点位的名称，也就是说点位表需要被轨迹信息表关联多次。
而我使用的 MyBatis-Plus-Join 无法对同一张表关联多次（特指 Lambda 形式）。

解决方式

研究源码后发现，无法对同一张表关联多次，主要原因在于框架中的以下两个字段：

/**
* Joined tables: entity class -> table index.
* Keyed by Class, so joining the same class a second time overwrites the index stored for the first join.
*/
protected Map<Class<?>, Integer> subTable = new HashMap<>();

/**
* Table index counter (starts at 1).
*/
private int tableIndex = 1;

框架用 subTable 记录本次查询所关联的表：以实体 Class 作为 Key、以 tableIndex 作为 Value。
由于 Key 是 Class，同一张表 Join 多次时，后一次会覆盖前一次记录的 tableIndex，
导致生成的 SQL 中，这张表的每一次 JOIN 都使用最后一次 JOIN 所分配的别名。例如：

// Both joins pass FwServicePoint.class, so both are looked up under the same subTable key
join.leftJoin(FwServicePoint.class, FwServicePoint::getServicePointId, FwTrackDistance::getSpBeginId);
join.leftJoin(FwServicePoint.class, FwServicePoint::getServicePointId, FwTrackDistance::getSpEndId);

同一个 FwServicePoint.class 得到的表别名（TABLE_ALIAS + tableIndex）是相同的，
因此生成出来的 SQL 会是：
JOIN FW_SERVICE_POINT t2 on(xx=xx)
JOIN FW_SERVICE_POINT t2 on(xx=xx)

因此有两种解决方式。

第一种：为 FW_SERVICE_POINT 这张表建立多个映射的 JavaBean（复制出来、仅用于 Join 的映射对象）。这样每次 Join 使用的 Class 不同，得到的别名也就不同：

// Each join uses a different Class, so each gets its own subTable entry and therefore its own alias
join.leftJoin(FwServicePointStart.class, FwServicePointStart::getServicePointId, FwTrackDistance::getSpBeginId);
join.leftJoin(FwServicePointEnd.class, FwServicePointEnd::getServicePointId, FwTrackDistance::getSpEndId);


// The regular entity mapped to FW_SERVICE_POINT
@TableName("FW_SERVICE_POINT")
public class FwServicePoint {
...
}
// Copy mapped to the same table, used only as the target of the "begin point" join
@TableName("FW_SERVICE_POINT")
public class FwServicePointStart {
//Copied mapping object used only for the join
}
// Copy mapped to the same table, used only as the target of the "end point" join
@TableName("FW_SERVICE_POINT")
public class FwServicePointEnd {
//Copied mapping object used only for the join
}

第二种：修改框架源码。核心改动如下：

/**
* Explicit table alias -> table index
*/
protected Map<String, Integer> tableAliasMap = new HashMap<>();
/*
Rough idea: leftJoin takes an extra table alias, and the tableIndex assigned to that join is recorded under the alias.
When building the join conditions, the tableIndex is looked up in tableAliasMap instead of subTable.

Add a tableAlias field to SelectColumn in MPJLambdaWrapper.
Add a tableAlias parameter to the leftJoin methods.
*/

修改后的调用方式为：

// Join the point table twice, registering each join under its own alias
join.leftJoin(FwServicePoint.class, FwServicePoint::getServicePointId, FwTrackDistance::getSpBeginId,"start");
join.leftJoin(FwServicePoint.class, FwServicePoint::getServicePointId, FwTrackDistance::getSpEndId, "end");
// Select the same column from each aliased join into different result properties
join.selectAs("start", FwServicePoint::getPlaceName, FwTrackDistance::getSpBeginName);
join.selectAs("end", FwServicePoint::getPlaceName, FwTrackDistance::getSpEndName);

修改后的源码如下。

JoinWrapper（对应原框架的 MPJLambdaWrapper）：

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import com.baomidou.mybatisplus.core.conditions.SharedString;
import com.baomidou.mybatisplus.core.conditions.segments.MergeSegments;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.ArrayUtils;
import com.baomidou.mybatisplus.core.toolkit.Assert;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
import com.github.yulichang.toolkit.Constant;
import com.github.yulichang.toolkit.LambdaUtils;
import com.github.yulichang.wrapper.enums.BaseFuncEnum;
import com.github.yulichang.wrapper.interfaces.Query;
import com.mysd.framework.mybatis.wrapper.interfaces.QuickJoinFunction;

import lombok.Data;
import lombok.Getter;

/**
 * Customized version of mybatis-plus-join's {@code MPJLambdaWrapper}.
 *
 * <p>Changes compared with the original:
 * <ul>
 * <li>The same entity class may be joined several times: give each join a distinct table alias and
 * use that alias in {@link #selectAs(String, SFunction, SFunction)}.</li>
 * <li>{@link #selectAll(Class)} skips fields whose {@code isSelect()} is false.</li>
 * </ul>
 *
 * <p>Limitation: column references rendered through {@code columnToString} (e.g. WHERE conditions)
 * resolve the table alias by entity class only, not by the explicit table alias.
 *
 * @param <T> main entity type
 */
public class JoinWrapper<T> extends JoinAbstractLambdaWrapper<T, JoinWrapper<T>>
        implements Query<JoinWrapper<T>>, LambdaJoin<JoinWrapper<T>> {

    /** Serialization version. */
    private static final long serialVersionUID = 1L;

    /** Cached SELECT column list; built on the first call to {@link #getSqlSelect()}. */
    private SharedString sqlSelect = new SharedString();

    /** Cached JOIN clauses; built on the first call to {@link #getFrom()}. */
    private final SharedString from = new SharedString();

    /** Alias of the main table. */
    private final SharedString alias = new SharedString(Constant.TABLE_ALIAS);

    /** Columns to select. */
    private final List<SelectColumn> selectColumns = new ArrayList<>();

    /** Columns to drop from the selection (only non-function columns are dropped). */
    private final List<SelectColumn> ignoreColumns = new ArrayList<>();

    /** Whether the query is SELECT DISTINCT. */
    private boolean selectDistinct = false;

    /** Index handed to the next joined table; the joined table's alias is TABLE_ALIAS + index. */
    private int tableIndex = 1;

    /** One wrapper per join, each carrying that join's ON conditions. */
    private final List<JoinWrapper<?>> onWrappers = new ArrayList<>();

    /** Join keyword (e.g. LEFT JOIN); only set on ON-clause wrappers created by {@link #instance}. */
    @Getter
    private String keyWord;

    /** Joined entity class; only set on ON-clause wrappers created by {@link #instance}. */
    @Getter
    private Class<?> joinClass;

    /** Explicit table alias of the join (may be null); only set on ON-clause wrappers. */
    @Getter
    private String tableAlias;

    /**
     * Creates an empty wrapper for the main query.
     */
    public JoinWrapper() {
        super.initNeed();
    }

    /**
     * Creates a wrapper that shares the parameter state and {@code subTable} of a parent wrapper; used by
     * {@link #instance(String, Class, String)} to build nested / ON-clause wrappers.
     */
    JoinWrapper(T entity, Class<T> entityClass, SharedString sqlSelect, AtomicInteger paramNameSeq,
            Map<String, Object> paramNameValuePairs, MergeSegments mergeSegments, SharedString lastSql,
            SharedString sqlComment, SharedString sqlFirst, Map<Class<?>, Integer> subTable, String keyWord,
            Class<?> joinClass, String tableAlias) {
        super.setEntity(entity);
        super.setEntityClass(entityClass);
        this.paramNameSeq = paramNameSeq;
        this.paramNameValuePairs = paramNameValuePairs;
        this.expression = mergeSegments;
        this.sqlSelect = sqlSelect;
        this.lastSql = lastSql;
        this.sqlComment = sqlComment;
        this.sqlFirst = sqlFirst;
        this.subTable = subTable;
        this.keyWord = keyWord;
        this.joinClass = joinClass;
        this.tableAlias = tableAlias;
    }

    /**
     * Turns the query into SELECT DISTINCT.
     */
    public JoinWrapper<T> distinct() {
        this.selectDistinct = true;
        return typedThis;
    }

    @Override
    @SafeVarargs
    public final <S> JoinWrapper<T> select(SFunction<S, ?>... columns) {
        if (ArrayUtils.isNotEmpty(columns)) {
            for (SFunction<S, ?> s : columns) {
                selectColumns.add(SelectColumn.of(LambdaUtils.getEntityClass(s), getCache(s).getColumn()));
            }
        }
        return typedThis;
    }

    @Override
    public <E> JoinWrapper<T> select(Class<E> entityClass, Predicate<TableFieldInfo> predicate) {
        TableInfo info = TableInfoHelper.getTableInfo(entityClass);
        Assert.notNull(info, "table can not be find");
        info.getFieldList().stream()
                .filter(predicate)
                .forEach(i -> selectColumns.add(SelectColumn.of(entityClass, i.getColumn())));
        return typedThis;
    }

    /**
     * Selects {@code column} from the join registered under {@code tableAlias}, exposing it under the
     * property name of {@code alias}.
     *
     * @param tableAlias alias passed to the join method; must have been registered before SQL is rendered
     * @param column     column of the joined entity
     * @param alias      getter whose property name becomes the column alias
     */
    public <S, X> JoinWrapper<T> selectAs(String tableAlias, SFunction<S, ?> column, SFunction<X, ?> alias) {
        selectColumns.add(SelectColumn.of(LambdaUtils.getEntityClass(column), getCache(column).getColumn(),
                LambdaUtils.getName(alias), tableAlias));
        return typedThis;
    }

    @Override
    public <S> JoinWrapper<T> selectAs(SFunction<S, ?> column, String alias) {
        selectColumns.add(SelectColumn.of(LambdaUtils.getEntityClass(column), getCache(column).getColumn(), alias));
        return typedThis;
    }

    /**
     * Selects {@code column} wrapped in the SQL function {@code funcEnum}, when {@code condition} holds.
     */
    public <S> JoinWrapper<T> selectFunc(boolean condition, BaseFuncEnum funcEnum, SFunction<S, ?> column,
            String alias) {
        if (condition) {
            selectColumns.add(
                    SelectColumn.of(LambdaUtils.getEntityClass(column), getCache(column).getColumn(), alias, funcEnum));
        }
        return typedThis;
    }

    @Override
    public JoinWrapper<T> selectFunc(boolean condition, BaseFuncEnum funcEnum, Object column, String alias) {
        if (condition) {
            selectColumns.add(SelectColumn.of(null, column.toString(), alias, funcEnum));
        }
        return typedThis;
    }

    /**
     * Selects the primary key and every field of {@code clazz} whose {@code isSelect()} is true.
     */
    public final JoinWrapper<T> selectAll(Class<?> clazz) {
        TableInfo info = TableInfoHelper.getTableInfo(clazz);
        Assert.notNull(info, "table can not be find -> %s", clazz);
        if (info.havePK()) {
            selectColumns.add(SelectColumn.of(clazz, info.getKeyColumn()));
        }
        info.getFieldList().forEach(c -> {
            if (c.isSelect()) {
                selectColumns.add(SelectColumn.of(clazz, c.getColumn()));
            }
        });
        return typedThis;
    }

    @Override
    @SafeVarargs
    public final <S> JoinWrapper<T> selectIgnore(SFunction<S, ?>... columns) {
        if (ArrayUtils.isNotEmpty(columns)) {
            for (SFunction<S, ?> s : columns) {
                ignoreColumns.add(SelectColumn.of(LambdaUtils.getEntityClass(s), getCache(s).getColumn()));
            }
        }
        return typedThis;
    }

    /**
     * Builds (once) and returns the SELECT column list.
     *
     * @throws com.baomidou.mybatisplus.core.exceptions.MybatisPlusException if a column references a table
     *         alias that was never registered by a join
     */
    @Override
    public String getSqlSelect() {
        if (StringUtils.isBlank(sqlSelect.getStringValue())) {
            if (CollectionUtils.isNotEmpty(ignoreColumns)) {
                selectColumns.removeIf(c -> c.getFuncEnum() == null && ignoreColumns.stream().anyMatch(
                        i -> i.getClazz() == c.getClazz() && Objects.equals(c.getColumnName(), i.getColumnName())));
            }
            String s = selectColumns.stream().map(this::toSelectSql).collect(Collectors.joining(StringPool.COMMA));
            sqlSelect.setStringValue(s);
        }
        return sqlSelect.getStringValue();
    }

    /** Renders one select column as {@code alias.column}, optionally wrapped in its function and AS-aliased. */
    private String toSelectSql(SelectColumn column) {
        String str = Constant.TABLE_ALIAS + getDefault(resolveTableIndex(column.getTableAlias(), column.getClazz()))
                + StringPool.DOT + column.getColumnName();
        String expr = column.getFuncEnum() == null ? str : String.format(column.getFuncEnum().getSql(), str);
        return expr + (StringUtils.isBlank(column.getAlias()) ? StringPool.EMPTY : (Constant.AS + column.getAlias()));
    }

    /**
     * Looks up the table index for an explicit alias, or by entity class when no alias is given.
     * An explicit alias that was never registered is rejected instead of silently falling back to the
     * main table alias.
     */
    private Integer resolveTableIndex(String aliasName, Class<?> clazz) {
        if (aliasName == null) {
            return subTable.get(clazz);
        }
        Integer index = tableAliasMap.get(aliasName);
        Assert.notNull(index, "unknown table alias -> %s", aliasName);
        return index;
    }

    /**
     * Builds (once) and returns the JOIN clauses, one per registered join.
     */
    public String getFrom() {
        if (StringUtils.isBlank(from.getStringValue())) {
            StringBuilder value = new StringBuilder();
            for (JoinWrapper<?> wrapper : onWrappers) {
                String tableName = TableInfoHelper.getTableInfo(wrapper.getJoinClass()).getTableName();

                // Alias the ON conditions were rendered with: columnToString resolves by class only.
                String tableNameDefAlias = Constant.TABLE_ALIAS + getDefault(subTable.get(wrapper.getJoinClass()));
                // Alias this particular join should actually use.
                String tableNameAlias = Constant.TABLE_ALIAS
                        + getDefault(resolveTableIndex(wrapper.getTableAlias(), wrapper.getJoinClass()));

                String whereSql = wrapper.getExpression().getNormal().getSqlSegment();
                // So that a table joined several times can appear on either side of the ON condition,
                // rewrite the class-based alias to this join's alias (literal, not regex, replacement).
                if (!tableNameDefAlias.equals(tableNameAlias)) {
                    whereSql = whereSql.replace(tableNameDefAlias + StringPool.DOT, tableNameAlias + StringPool.DOT);
                }

                value.append(wrapper.getKeyWord()).append(tableName).append(StringPool.SPACE).append(tableNameAlias)
                        .append(Constant.ON).append(whereSql);
            }
            from.setStringValue(value.toString());
        }
        return from.getStringValue();
    }

    /** Returns the main table alias. */
    public String getAlias() {
        return alias.getStringValue();
    }

    /** Returns whether SELECT DISTINCT was requested. */
    public boolean getSelectDistinct() {
        return selectDistinct;
    }

    /**
     * Creates a wrapper for nested SQL.
     * <p>
     * The sqlSelect is therefore not passed down.
     * </p>
     */
    @Override
    protected JoinWrapper<T> instance() {
        return instance(null, null, null);
    }

    /**
     * Creates a child wrapper sharing this wrapper's parameter state and {@code subTable}, used to build the
     * ON conditions of a join.
     */
    protected JoinWrapper<T> instance(String keyWord, Class<?> joinClass, String tableAlias) {
        return new JoinWrapper<>(getEntity(), getEntityClass(), null, paramNameSeq, paramNameValuePairs,
                new MergeSegments(), SharedString.emptyString(), SharedString.emptyString(), SharedString.emptyString(),
                this.subTable, keyWord, joinClass, tableAlias);
    }

    /**
     * Resets the wrapper, including all registered joins.
     */
    @Override
    public void clear() {
        super.clear();
        sqlSelect.toNull();
        from.toNull();
        selectColumns.clear();
        ignoreColumns.clear();
        // Joins must be dropped together with their index maps; otherwise getFrom() would render the
        // stale ON wrappers with aliases that no longer resolve.
        onWrappers.clear();
        subTable.clear();
        tableAliasMap.clear();
        tableIndex = 1;
    }

    /**
     * Registers a join. When {@code tableAlias} is non-null it is bound to this join's table index so that
     * the same entity class can be joined again under another alias.
     *
     * @throws com.baomidou.mybatisplus.core.exceptions.MybatisPlusException if {@code tableAlias} is already
     *         used by an earlier join
     */
    @Override
    public <R> JoinWrapper<T> join(String keyWord, boolean condition, Class<R> clazz, QuickJoinFunction function,
            String tableAlias) {
        if (!condition) {
            return typedThis;
        }
        if (tableAlias != null) {
            // A reused alias would silently redirect the earlier join's selects to this join.
            Assert.isFalse(tableAliasMap.containsKey(tableAlias), "duplicate table alias -> %s", tableAlias);
        }
        JoinWrapper<?> apply = function.apply(instance(keyWord, clazz, tableAlias));
        onWrappers.add(apply);
        subTable.put(clazz, tableIndex);
        if (tableAlias != null) {
            tableAliasMap.put(tableAlias, tableIndex);
        }
        tableIndex++;
        return typedThis;
    }

    /**
     * Appends the literal clause {@code limit num} through {@code last(...)}; only valid for SQL dialects
     * that support LIMIT.
     */
    public JoinWrapper<T> limit(Integer num){
        return last("limit " + num);
    }

    /**
     * One column of the SELECT list.
     */
    @Data
    private static class SelectColumn {

        /** Entity class the column belongs to (null for raw column expressions). */
        private Class<?> clazz;

        /** Explicit table alias of the join to select from, or null to resolve by entity class. */
        private String tableAlias;

        /** Database column name. */
        private String columnName;

        /** Result alias (rendered with AS), or null. */
        private String alias;

        /** SQL function the column is wrapped in, or null. */
        private BaseFuncEnum funcEnum;

        /** Arguments for custom functions (not set anywhere in this class). */
        private List<SFunction<?, ?>> funcArgs;

        private SelectColumn(Class<?> clazz, String columnName, String alias, String tableAlias,
                BaseFuncEnum funcEnum) {
            this.clazz = clazz;
            this.columnName = columnName;
            this.alias = alias;
            this.tableAlias = tableAlias;
            this.funcEnum = funcEnum;
        }

        public static SelectColumn of(Class<?> clazz, String columnName) {
            return new SelectColumn(clazz, columnName, null, null, null);
        }

        public static SelectColumn of(Class<?> clazz, String columnName, String alias) {
            return new SelectColumn(clazz, columnName, alias, null, null);
        }

        public static SelectColumn of(Class<?> clazz, String columnName, String alias, BaseFuncEnum funcEnum) {
            return new SelectColumn(clazz, columnName, alias, null, funcEnum);
        }

        public static SelectColumn of(Class<?> clazz, String columnName, String alias, String tableAlias) {
            return new SelectColumn(clazz, columnName, alias, tableAlias, null);
        }
    }
}

JoinAbstractLambdaWrapper（对应原框架的 MPJAbstractLambdaWrapper）：


import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache;
import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
import com.github.yulichang.toolkit.Constant;
import com.github.yulichang.toolkit.LambdaUtils;
import com.github.yulichang.wrapper.MPJAbstractWrapper;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static java.util.stream.Collectors.joining;

/**
 * Customized version of mybatis-plus-join's {@code MPJAbstractLambdaWrapper}.
 *
 * <p>Holds the table-index bookkeeping used to render aliases ({@code Constant.TABLE_ALIAS + index})
 * and a per-class cache of lambda-to-column mappings.
 *
 * @param <T>        main entity type
 * @param <Children> concrete wrapper type returned for chaining
 */
public abstract class JoinAbstractLambdaWrapper<T, Children extends JoinAbstractLambdaWrapper<T, Children>>
        extends MPJAbstractWrapper<T, Children> {

    /**
     * Joined entity class -> table index.
     * Keyed by class, so a class maps to only one index at a time.
     */
    protected Map<Class<?>, Integer> subTable = new HashMap<>();

    /**
     * Explicit table alias -> table index (not a table name).
     */
    protected Map<String, Integer> tableAliasMap = new HashMap<>();

    /**
     * Cache: entity class -> (formatted property key -> column info).
     */
    protected Map<Class<?>, Map<String, ColumnCache>> columnMap = new HashMap<>();

    @Override
    protected <X> String columnToString(X column) {
        return columnToString((SFunction<?, ?>) column);
    }

    @Override
    @SafeVarargs
    protected final <X> String columnsToString(X... columns) {
        return Arrays.stream(columns).map(i -> columnToString((SFunction<?, ?>) i)).collect(joining(StringPool.COMMA));
    }

    /**
     * Renders {@code column} as {@code alias.column_name}.
     * <p>
     * The alias is resolved through {@link #subTable} (by entity class) only; explicit table aliases are not
     * consulted here, because filtering on a specific aliased join in WHERE was not needed.
     */
    protected String columnToString(SFunction<?, ?> column) {
        return Constant.TABLE_ALIAS + getDefault(subTable.get(LambdaUtils.getEntityClass(column))) + StringPool.DOT +
                getCache(column).getColumn();
    }

    /**
     * Returns the cached column info for the property referenced by {@code fn}, loading and caching the
     * column map of its entity class on first use. May return null if the property has no column mapping.
     */
    protected ColumnCache getCache(SFunction<?, ?> fn) {
        Map<String, ColumnCache> cacheMap =
                columnMap.computeIfAbsent(LambdaUtils.getEntityClass(fn), c -> LambdaUtils.getColumnMap(c));
        return cacheMap.get(LambdaUtils.formatKey(LambdaUtils.getName(fn)));
    }

    /**
     * Returns {@code i} as a string, or the empty string when {@code i} is null (main table alias).
     */
    protected String getDefault(Integer i) {
        return Objects.toString(i, StringPool.EMPTY);
    }

}

LambdaJoin（连表方法接口，新增了带 tableAlias 参数的重载）：

import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
import com.github.yulichang.interfaces.MPJBaseJoin;
import com.github.yulichang.toolkit.Constant;
import com.mysd.framework.mybatis.wrapper.interfaces.QuickJoinFunction;

/**
 * Lambda-style join methods.
 *
 * <p>Each join type has overloads taking a {@code tableAlias}. Give every join of the same entity class a
 * distinct alias so that the class can be joined more than once. Overloads without an alias pass
 * {@code null}.
 *
 * @param <Children> concrete wrapper type returned for chaining
 */
public interface LambdaJoin<Children> extends MPJBaseJoin {

    /**
     * left join
     *
     * @param clazz joined entity class
     * @param left  condition (column of the joined entity)
     * @param right condition (column it is compared with)
     */
    default <T, X> Children leftJoin(Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right) {
        return leftJoin(true, clazz, left, right, null);
    }

    /**
     * left join registered under {@code tableAlias}.
     *
     * @param tableAlias alias of this join; distinct per join of the same class
     */
    default <T, X> Children leftJoin(Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right, String tableAlias) {
        return leftJoin(true, clazz, left, right, tableAlias);
    }

    /**
     * left join
     * <p>
     * e.g. leftJoin(UserDO.class, on ->
     * on.eq(UserDO::getId,UserAddressDO::getUserId).le().gt()...)
     *
     * @param clazz    joined entity class
     * @param function ON conditions
     */
    default <T> Children leftJoin(Class<T> clazz, QuickJoinFunction function) {
        return leftJoin(true, clazz, function, null);
    }

    /**
     * left join with custom ON conditions, registered under {@code tableAlias}.
     */
    default <T> Children leftJoin(Class<T> clazz, QuickJoinFunction function, String tableAlias) {
        return leftJoin(true, clazz, function, tableAlias);
    }

    /**
     * left join
     *
     * @param condition  whether to apply the join
     * @param clazz      joined entity class
     * @param left       condition (column of the joined entity)
     * @param right      condition (column it is compared with)
     * @param tableAlias alias of this join, or null
     */
    default <T, X> Children leftJoin(boolean condition, Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right,
            String tableAlias) {
        return leftJoin(condition, clazz, on -> on.eq(left, right), tableAlias);
    }

    /**
     * left join
     * <p>
     * e.g. leftJoin(UserDO.class, on ->
     * on.eq(UserDO::getId,UserAddressDO::getUserId).le().gt()...)
     *
     * @param condition  whether to apply the join
     * @param clazz      joined entity class
     * @param function   ON conditions
     * @param tableAlias alias of this join, or null
     */
    default <T> Children leftJoin(boolean condition, Class<T> clazz, QuickJoinFunction function, String tableAlias) {
        return join(Constant.LEFT_JOIN, condition, clazz, function, tableAlias);
    }

    /**
     * ignore: see left join
     */
    default <T, X> Children rightJoin(Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right) {
        return rightJoin(true, clazz, left, right);
    }

    /**
     * ignore: see left join (aliased variant)
     */
    default <T, X> Children rightJoin(Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right, String tableAlias) {
        return rightJoin(true, clazz, left, right, tableAlias);
    }

    /**
     * ignore: see left join
     */
    default <T> Children rightJoin(Class<T> clazz, QuickJoinFunction function) {
        return rightJoin(true, clazz, function);
    }

    /**
     * ignore: see left join (aliased variant)
     */
    default <T> Children rightJoin(Class<T> clazz, QuickJoinFunction function, String tableAlias) {
        return rightJoin(true, clazz, function, tableAlias);
    }

    /**
     * ignore: see left join
     */
    default <T, X> Children rightJoin(boolean condition, Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right) {
        return rightJoin(condition, clazz, on -> on.eq(left, right));
    }

    /**
     * ignore: see left join (aliased variant)
     */
    default <T, X> Children rightJoin(boolean condition, Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right,
            String tableAlias) {
        return rightJoin(condition, clazz, on -> on.eq(left, right), tableAlias);
    }

    /**
     * ignore: see left join
     */
    default <T> Children rightJoin(boolean condition, Class<T> clazz, QuickJoinFunction function) {
        return join(Constant.RIGHT_JOIN, condition, clazz, function, null);
    }

    /**
     * ignore: see left join (aliased variant)
     */
    default <T> Children rightJoin(boolean condition, Class<T> clazz, QuickJoinFunction function, String tableAlias) {
        return join(Constant.RIGHT_JOIN, condition, clazz, function, tableAlias);
    }

    /**
     * ignore: see left join
     */
    default <T, X> Children innerJoin(Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right) {
        return innerJoin(true, clazz, left, right);
    }

    /**
     * ignore: see left join (aliased variant)
     */
    default <T, X> Children innerJoin(Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right, String tableAlias) {
        return innerJoin(true, clazz, left, right, tableAlias);
    }

    /**
     * ignore: see left join
     */
    default <T> Children innerJoin(Class<T> clazz, QuickJoinFunction function) {
        return innerJoin(true, clazz, function);
    }

    /**
     * ignore: see left join (aliased variant)
     */
    default <T> Children innerJoin(Class<T> clazz, QuickJoinFunction function, String tableAlias) {
        return innerJoin(true, clazz, function, tableAlias);
    }

    /**
     * ignore: see left join
     */
    default <T, X> Children innerJoin(boolean condition, Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right) {
        return innerJoin(condition, clazz, on -> on.eq(left, right));
    }

    /**
     * ignore: see left join (aliased variant)
     */
    default <T, X> Children innerJoin(boolean condition, Class<T> clazz, SFunction<T, ?> left, SFunction<X, ?> right,
            String tableAlias) {
        return innerJoin(condition, clazz, on -> on.eq(left, right), tableAlias);
    }

    /**
     * ignore: see left join
     */
    default <T> Children innerJoin(boolean condition, Class<T> clazz, QuickJoinFunction function) {
        return join(Constant.INNER_JOIN, condition, clazz, function, null);
    }

    /**
     * ignore: see left join (aliased variant)
     */
    default <T> Children innerJoin(boolean condition, Class<T> clazz, QuickJoinFunction function, String tableAlias) {
        return join(Constant.INNER_JOIN, condition, clazz, function, tableAlias);
    }

    /**
     * Base join method; every method above delegates to it.
     *
     * @param keyWord    join keyword
     * @param condition  whether to apply the join
     * @param clazz      joined entity class
     * @param function   ON conditions
     * @param tableAlias alias of this join, or null to identify the join by its entity class only
     */
    <T> Children join(String keyWord, boolean condition, Class<T> clazz, QuickJoinFunction function, String tableAlias);

}