1
1
package org .neo4j .graphalgo .similarity ;
2
2
3
+ import org .junit .Before ;
3
4
import org .junit .Rule ;
4
5
import org .junit .Test ;
6
+ import org .junit .runner .RunWith ;
7
+ import org .junit .runners .Parameterized ;
5
8
import org .neo4j .graphdb .Transaction ;
6
9
import org .neo4j .kernel .internal .GraphDatabaseAPI ;
10
+ import org .neo4j .logging .Log ;
7
11
import org .neo4j .logging .NullLog ;
8
12
import org .neo4j .test .rule .ImpermanentDatabaseRule ;
9
13
14
+ import java .lang .invoke .MethodHandle ;
15
+ import java .lang .invoke .MethodHandles ;
16
+ import java .lang .invoke .MethodType ;
17
+ import java .util .Arrays ;
18
+ import java .util .Collection ;
10
19
import java .util .List ;
11
20
import java .util .Objects ;
12
21
import java .util .stream .Collectors ;
17
26
import static org .junit .Assert .assertEquals ;
18
27
import static org .junit .Assert .assertThat ;
19
28
29
+ @ RunWith (Parameterized .class )
20
30
public class SimilarityExporterTest {
21
31
@ Rule
22
32
public final ImpermanentDatabaseRule DB = new ImpermanentDatabaseRule ();
23
33
34
+ private static final MethodHandles .Lookup LOOKUP = MethodHandles .lookup ();
35
+ private static final MethodType CTOR_METHOD = MethodType .methodType (
36
+ void .class ,
37
+ GraphDatabaseAPI .class ,
38
+ Log .class ,
39
+ String .class ,
40
+ String .class ,
41
+ int .class );
42
+
24
43
private static final String RELATIONSHIP_TYPE = "SIMILAR" ;
25
44
private static final String PROPERTY_NAME = "score" ;
45
+ private SimilarityExporter exporter ;
46
+ private GraphDatabaseAPI api ;
47
+ private Class <? extends SimilarityExporter > similarityExporterFactory ;
48
+
49
+ @ Parameterized .Parameters (name = "{1}" )
50
+ public static Collection <Object []> data () {
51
+ return Arrays .asList (
52
+ new Object []{SequentialSimilarityExporter .class , "Sequential" },
53
+ new Object []{ParallelSimilarityExporter .class , "Parallel" }
54
+ );
55
+ }
26
56
27
- @ Test
28
- public void createNothing () {
29
- GraphDatabaseAPI api = DB .getGraphDatabaseAPI ();
30
- createNodes (api , 2 );
57
+ @ Before
58
+ public void setup () {
59
+ api = DB .getGraphDatabaseAPI ();
60
+ }
61
+
62
+ public SimilarityExporterTest (Class <? extends SimilarityExporter > similarityExporterFactory ,
63
+ String ignoreParamOnlyForTestNaming ) throws Throwable {
64
+
65
+ this .similarityExporterFactory = similarityExporterFactory ;
66
+ }
31
67
32
- SequentialSimilarityExporter exporter = new SequentialSimilarityExporter (api , NullLog .getInstance (), RELATIONSHIP_TYPE , PROPERTY_NAME );
68
+ public SimilarityExporter load (Class <? extends SimilarityExporter > factoryType , int nodeCount ) throws Throwable {
69
+ final MethodHandle constructor = findConstructor (factoryType );
70
+ return (SimilarityExporter ) constructor .invoke (api , NullLog .getInstance (), RELATIONSHIP_TYPE , PROPERTY_NAME , nodeCount );
71
+ }
72
+
73
+ private MethodHandle findConstructor (Class <?> factoryType ) {
74
+ try {
75
+ return LOOKUP .findConstructor (factoryType , CTOR_METHOD );
76
+ } catch (NoSuchMethodException | IllegalAccessException e ) {
77
+ throw new RuntimeException (e );
78
+ }
79
+ }
80
+
81
+ @ Test
82
+ public void createNothing () throws Throwable {
83
+ int nodeCount = 2 ;
84
+ createNodes (api , nodeCount );
85
+ exporter = load (similarityExporterFactory , nodeCount );
33
86
34
87
Stream <SimilarityResult > similarityPairs = Stream .empty ();
35
88
@@ -43,11 +96,10 @@ public void createNothing() {
43
96
}
44
97
45
98
@ Test
46
- public void createOneRelationship () {
47
- GraphDatabaseAPI api = DB .getGraphDatabaseAPI ();
48
- createNodes (api , 2 );
49
-
50
- SequentialSimilarityExporter exporter = new SequentialSimilarityExporter (api , NullLog .getInstance (), RELATIONSHIP_TYPE , PROPERTY_NAME );
99
+ public void createOneRelationship () throws Throwable {
100
+ int nodeCount = 2 ;
101
+ createNodes (api , nodeCount );
102
+ exporter = load (similarityExporterFactory , nodeCount );
51
103
52
104
Stream <SimilarityResult > similarityPairs = Stream .of (new SimilarityResult (0 , 1 , -1 , -1 , -1 , 0.5 ));
53
105
@@ -62,11 +114,12 @@ public void createOneRelationship() {
62
114
}
63
115
64
116
@ Test
65
- public void multipleBatches () {
66
- GraphDatabaseAPI api = DB .getGraphDatabaseAPI ();
67
- createNodes (api , 4 );
117
+ public void multipleBatches () throws Throwable {
118
+ int nodeCount = 4 ;
119
+ createNodes (api , nodeCount );
120
+ exporter = load (similarityExporterFactory , nodeCount );
68
121
69
- SequentialSimilarityExporter exporter = new SequentialSimilarityExporter (api , NullLog .getInstance (), RELATIONSHIP_TYPE , PROPERTY_NAME );
122
+ SimilarityExporter exporter = new SequentialSimilarityExporter (api , NullLog .getInstance (), RELATIONSHIP_TYPE , PROPERTY_NAME , 4 );
70
123
71
124
Stream <SimilarityResult > similarityPairs = Stream .of (
72
125
new SimilarityResult (0 , 1 , -1 , -1 , -1 , 0.5 ),
@@ -86,16 +139,16 @@ public void multipleBatches() {
86
139
}
87
140
88
141
@ Test
89
- public void smallerThanBatchSize () {
90
- GraphDatabaseAPI api = DB .getGraphDatabaseAPI ();
91
- createNodes (api , 5 );
92
-
93
- SequentialSimilarityExporter exporter = new SequentialSimilarityExporter (api , NullLog .getInstance (), RELATIONSHIP_TYPE , PROPERTY_NAME );
142
+ public void smallerThanBatchSize () throws Throwable {
143
+ int nodeCount = 5 ;
144
+ createNodes (api , nodeCount );
145
+ exporter = load (similarityExporterFactory , nodeCount );
94
146
95
147
Stream <SimilarityResult > similarityPairs = Stream .of (
96
148
new SimilarityResult (0 , 1 , -1 , -1 , -1 , 0.5 ),
149
+ new SimilarityResult (1 , 2 , -1 , -1 , -1 , 0.6 ),
97
150
new SimilarityResult (2 , 3 , -1 , -1 , -1 , 0.7 ),
98
- new SimilarityResult (3 , 4 , -1 , -1 , -1 , 0.7 )
151
+ new SimilarityResult (3 , 4 , -1 , -1 , -1 , 0.8 )
99
152
);
100
153
101
154
int batches = exporter .export (similarityPairs , 10 );
@@ -104,10 +157,11 @@ public void smallerThanBatchSize() {
104
157
try (Transaction tx = api .beginTx ()) {
105
158
List <SimilarityRelationship > allRelationships = getSimilarityRelationships (api );
106
159
107
- assertThat (allRelationships , hasSize (3 ));
160
+ assertThat (allRelationships , hasSize (4 ));
108
161
assertThat (allRelationships , hasItems (new SimilarityRelationship (0 , 1 , 0.5 )));
162
+ assertThat (allRelationships , hasItems (new SimilarityRelationship (1 , 2 , 0.6 )));
109
163
assertThat (allRelationships , hasItems (new SimilarityRelationship (2 , 3 , 0.7 )));
110
- assertThat (allRelationships , hasItems (new SimilarityRelationship (3 , 4 , 0.7 )));
164
+ assertThat (allRelationships , hasItems (new SimilarityRelationship (3 , 4 , 0.8 )));
111
165
}
112
166
}
113
167
0 commit comments