Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Commit 17d4502

Browse files
committed
parameterise similarity test
1 parent 87df54f commit 17d4502

File tree

4 files changed

+83
-24
lines changed

4 files changed

+83
-24
lines changed

algo/src/main/java/org/neo4j/graphalgo/similarity/ParallelSimilarityExporter.java

+5
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ public int export(Stream<SimilarityResult> similarityPairs, long batchSize) {
9090
.stream();
9191

9292
int queueSize = dssResult.getSetCount();
93+
94+
if(queueSize == 0) {
95+
return 0;
96+
}
97+
9398
log.info("ParallelSimilarityExporter: Relationships to be created: %d, Partitions found: %d", numberOfRelationships[0], queueSize);
9499

95100
ArrayBlockingQueue<List<SimilarityResult>> outQueue = new ArrayBlockingQueue<>(queueSize);

algo/src/main/java/org/neo4j/graphalgo/similarity/SequentialSimilarityExporter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public class SequentialSimilarityExporter extends StatementApi implements Simila
4242

4343
public SequentialSimilarityExporter(GraphDatabaseAPI api,
4444
Log log, String relationshipType,
45-
String propertyName) {
45+
String propertyName, int nodeCount) {
4646
super(api);
4747
this.log = log;
4848
propertyId = getOrCreatePropertyId(propertyName);

algo/src/main/java/org/neo4j/graphalgo/similarity/SimilarityProc.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ Stream<SimilaritySummaryResult> writeAndAggregateResults(Stream<SimilarityResult
161161

162162
} else {
163163
try (ProgressTimer timer = builder.timeWrite()) {
164-
SequentialSimilarityExporter similarityExporter = new SequentialSimilarityExporter(api, log, writeRelationshipType, writeProperty);
164+
SequentialSimilarityExporter similarityExporter = new SequentialSimilarityExporter(api, log, writeRelationshipType, writeProperty, length);
165165
similarityExporter.export(stream.peek(recorder), writeBatchSize);
166166
}
167167
}

tests/src/test/java/org/neo4j/graphalgo/similarity/SimilarityExporterTest.java

+76-22
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
package org.neo4j.graphalgo.similarity;
22

3+
import org.junit.Before;
34
import org.junit.Rule;
45
import org.junit.Test;
6+
import org.junit.runner.RunWith;
7+
import org.junit.runners.Parameterized;
58
import org.neo4j.graphdb.Transaction;
69
import org.neo4j.kernel.internal.GraphDatabaseAPI;
10+
import org.neo4j.logging.Log;
711
import org.neo4j.logging.NullLog;
812
import org.neo4j.test.rule.ImpermanentDatabaseRule;
913

14+
import java.lang.invoke.MethodHandle;
15+
import java.lang.invoke.MethodHandles;
16+
import java.lang.invoke.MethodType;
17+
import java.util.Arrays;
18+
import java.util.Collection;
1019
import java.util.List;
1120
import java.util.Objects;
1221
import java.util.stream.Collectors;
@@ -17,19 +26,63 @@
1726
import static org.junit.Assert.assertEquals;
1827
import static org.junit.Assert.assertThat;
1928

29+
@RunWith(Parameterized.class)
2030
public class SimilarityExporterTest {
2131
@Rule
2232
public final ImpermanentDatabaseRule DB = new ImpermanentDatabaseRule();
2333

34+
private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
35+
private static final MethodType CTOR_METHOD = MethodType.methodType(
36+
void.class,
37+
GraphDatabaseAPI.class,
38+
Log.class,
39+
String.class,
40+
String.class,
41+
int.class);
42+
2443
private static final String RELATIONSHIP_TYPE = "SIMILAR";
2544
private static final String PROPERTY_NAME = "score";
45+
private SimilarityExporter exporter;
46+
private GraphDatabaseAPI api;
47+
private Class<? extends SimilarityExporter> similarityExporterFactory;
48+
49+
@Parameterized.Parameters(name = "{1}")
50+
public static Collection<Object[]> data() {
51+
return Arrays.asList(
52+
new Object[]{SequentialSimilarityExporter.class, "Sequential"},
53+
new Object[]{ParallelSimilarityExporter.class, "Parallel"}
54+
);
55+
}
2656

27-
@Test
28-
public void createNothing() {
29-
GraphDatabaseAPI api = DB.getGraphDatabaseAPI();
30-
createNodes(api, 2);
57+
@Before
58+
public void setup() {
59+
api = DB.getGraphDatabaseAPI();
60+
}
61+
62+
public SimilarityExporterTest(Class<? extends SimilarityExporter> similarityExporterFactory,
63+
String ignoreParamOnlyForTestNaming) throws Throwable {
64+
65+
this.similarityExporterFactory = similarityExporterFactory;
66+
}
3167

32-
SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME);
68+
public SimilarityExporter load(Class<? extends SimilarityExporter> factoryType, int nodeCount) throws Throwable {
69+
final MethodHandle constructor = findConstructor(factoryType);
70+
return (SimilarityExporter) constructor.invoke(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME, nodeCount);
71+
}
72+
73+
private MethodHandle findConstructor(Class<?> factoryType) {
74+
try {
75+
return LOOKUP.findConstructor(factoryType, CTOR_METHOD);
76+
} catch (NoSuchMethodException | IllegalAccessException e) {
77+
throw new RuntimeException(e);
78+
}
79+
}
80+
81+
@Test
82+
public void createNothing() throws Throwable {
83+
int nodeCount = 2;
84+
createNodes(api, nodeCount);
85+
exporter = load(similarityExporterFactory, nodeCount);
3386

3487
Stream<SimilarityResult> similarityPairs = Stream.empty();
3588

@@ -43,11 +96,10 @@ public void createNothing() {
4396
}
4497

4598
@Test
46-
public void createOneRelationship() {
47-
GraphDatabaseAPI api = DB.getGraphDatabaseAPI();
48-
createNodes(api, 2);
49-
50-
SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME);
99+
public void createOneRelationship() throws Throwable {
100+
int nodeCount = 2;
101+
createNodes(api, nodeCount);
102+
exporter = load(similarityExporterFactory, nodeCount);
51103

52104
Stream<SimilarityResult> similarityPairs = Stream.of(new SimilarityResult(0, 1, -1, -1, -1, 0.5));
53105

@@ -62,11 +114,12 @@ public void createOneRelationship() {
62114
}
63115

64116
@Test
65-
public void multipleBatches() {
66-
GraphDatabaseAPI api = DB.getGraphDatabaseAPI();
67-
createNodes(api, 4);
117+
public void multipleBatches() throws Throwable {
118+
int nodeCount = 4;
119+
createNodes(api, nodeCount);
120+
exporter = load(similarityExporterFactory, nodeCount);
68121

69-
SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME);
122+
SimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME, 4);
70123

71124
Stream<SimilarityResult> similarityPairs = Stream.of(
72125
new SimilarityResult(0, 1, -1, -1, -1, 0.5),
@@ -86,16 +139,16 @@ public void multipleBatches() {
86139
}
87140

88141
@Test
89-
public void smallerThanBatchSize() {
90-
GraphDatabaseAPI api = DB.getGraphDatabaseAPI();
91-
createNodes(api, 5);
92-
93-
SequentialSimilarityExporter exporter = new SequentialSimilarityExporter(api, NullLog.getInstance(), RELATIONSHIP_TYPE, PROPERTY_NAME);
142+
public void smallerThanBatchSize() throws Throwable {
143+
int nodeCount = 5;
144+
createNodes(api, nodeCount);
145+
exporter = load(similarityExporterFactory, nodeCount);
94146

95147
Stream<SimilarityResult> similarityPairs = Stream.of(
96148
new SimilarityResult(0, 1, -1, -1, -1, 0.5),
149+
new SimilarityResult(1, 2, -1, -1, -1, 0.6),
97150
new SimilarityResult(2, 3, -1, -1, -1, 0.7),
98-
new SimilarityResult(3, 4, -1, -1, -1, 0.7)
151+
new SimilarityResult(3, 4, -1, -1, -1, 0.8)
99152
);
100153

101154
int batches = exporter.export(similarityPairs, 10);
@@ -104,10 +157,11 @@ public void smallerThanBatchSize() {
104157
try (Transaction tx = api.beginTx()) {
105158
List<SimilarityRelationship> allRelationships = getSimilarityRelationships(api);
106159

107-
assertThat(allRelationships, hasSize(3));
160+
assertThat(allRelationships, hasSize(4));
108161
assertThat(allRelationships, hasItems(new SimilarityRelationship(0, 1, 0.5)));
162+
assertThat(allRelationships, hasItems(new SimilarityRelationship(1, 2, 0.6)));
109163
assertThat(allRelationships, hasItems(new SimilarityRelationship(2, 3, 0.7)));
110-
assertThat(allRelationships, hasItems(new SimilarityRelationship(3, 4, 0.7)));
164+
assertThat(allRelationships, hasItems(new SimilarityRelationship(3, 4, 0.8)));
111165
}
112166
}
113167

0 commit comments

Comments
 (0)