Skip to content

Commit 4528d53

Browse files
Equal PersistentOrderedSets are not equal (#217)
Fixes tree invariant and Iterations issues Fixes #204
1 parent 6e66748 commit 4528d53

File tree

7 files changed

+362
-25
lines changed

7 files changed

+362
-25
lines changed

core/commonMain/src/implementations/immutableMap/PersistentHashMapBuilderContentIterators.kt

+18-3
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ internal open class PersistentHashMapBuilderBaseIterator<K, V, T>(
5757
val currentKey = currentKey()
5858

5959
builder.remove(lastIteratedKey)
60-
resetPath(currentKey.hashCode(), builder.node, currentKey, 0)
60+
resetPath(currentKey.hashCode(), builder.node, currentKey, 0, lastIteratedKey.hashCode(), afterRemove = true)
6161
} else {
6262
builder.remove(lastIteratedKey)
6363
}
@@ -82,7 +82,7 @@ internal open class PersistentHashMapBuilderBaseIterator<K, V, T>(
8282
expectedModCount = builder.modCount
8383
}
8484

85-
private fun resetPath(keyHash: Int, node: TrieNode<*, *>, key: K, pathIndex: Int) {
85+
private fun resetPath(keyHash: Int, node: TrieNode<*, *>, key: K, pathIndex: Int, removedKeyHash: Int = 0, afterRemove: Boolean = false) {
8686
val shift = pathIndex * LOG_MAX_BRANCHING_FACTOR
8787

8888
if (shift > MAX_SHIFT) { // collision
@@ -99,6 +99,21 @@ internal open class PersistentHashMapBuilderBaseIterator<K, V, T>(
9999
if (node.hasEntryAt(keyPositionMask)) { // key is directly in buffer
100100
val keyIndex = node.entryKeyIndex(keyPositionMask)
101101

102+
// After removing an element, we need to handle node promotion properly to maintain a correct iteration order.
103+
// `removedKeyPositionMask` represents the bit position of the removed key's hash at the current level.
104+
// This is needed to detect if the current key was potentially promoted from a deeper level.
105+
val removedKeyPositionMask = if (afterRemove) 1 shl indexSegment(removedKeyHash, shift) else 0
106+
107+
// Check if the removed key is at the same position as the current key and was previously at a deeper level.
108+
// This indicates a node promotion occurred during removal,
109+
// and we need to handle it in a special way to prevent re-traversing already visited elements.
110+
if (keyPositionMask == removedKeyPositionMask && pathIndex < pathLastIndex) {
111+
// Instead of traversing the normal way, we create a special path entry at the previous depth
112+
// that points directly to the promoted entry, maintaining the original iteration sequence.
113+
path[pathLastIndex].reset(arrayOf(node.buffer[keyIndex], node.buffer[keyIndex + 1]), ENTRY_SIZE)
114+
return
115+
}
116+
102117
// assert(node.keyAtIndex(keyIndex) == key)
103118

104119
path[pathIndex].reset(node.buffer, ENTRY_SIZE * node.entryCount(), keyIndex)
@@ -111,7 +126,7 @@ internal open class PersistentHashMapBuilderBaseIterator<K, V, T>(
111126
val nodeIndex = node.nodeIndex(keyPositionMask)
112127
val targetNode = node.nodeAtIndex(nodeIndex)
113128
path[pathIndex].reset(node.buffer, ENTRY_SIZE * node.entryCount(), nodeIndex)
114-
resetPath(keyHash, targetNode, key, pathIndex + 1)
129+
resetPath(keyHash, targetNode, key, pathIndex + 1, removedKeyHash, afterRemove)
115130
}
116131

117132
private fun checkNextWasInvoked() {

core/commonMain/src/implementations/immutableMap/TrieNode.kt

+6-22
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ internal class TrieNode<K, V>(
180180
}
181181

182182
/** The given [newNode] must not be a part of any persistent map instance. */
183-
private fun updateNodeAtIndex(nodeIndex: Int, positionMask: Int, newNode: TrieNode<K, V>): TrieNode<K, V> {
183+
private fun updateNodeAtIndex(nodeIndex: Int, positionMask: Int, newNode: TrieNode<K, V>, owner: MutabilityOwnership? = null): TrieNode<K, V> {
184184
// assert(buffer[nodeIndex] !== newNode)
185185
val newNodeBuffer = newNode.buffer
186186
if (newNodeBuffer.size == 2 && newNode.nodeMap == 0) {
@@ -192,30 +192,14 @@ internal class TrieNode<K, V>(
192192

193193
val keyIndex = entryKeyIndex(positionMask)
194194
val newBuffer = buffer.replaceNodeWithEntry(nodeIndex, keyIndex, newNodeBuffer[0], newNodeBuffer[1])
195-
return TrieNode(dataMap xor positionMask, nodeMap xor positionMask, newBuffer)
195+
return TrieNode(dataMap xor positionMask, nodeMap xor positionMask, newBuffer, owner)
196196
}
197197

198-
val newBuffer = buffer.copyOf(buffer.size)
199-
newBuffer[nodeIndex] = newNode
200-
return TrieNode(dataMap, nodeMap, newBuffer)
201-
}
202-
203-
/** The given [newNode] must not be a part of any persistent map instance. */
204-
private fun mutableUpdateNodeAtIndex(nodeIndex: Int, newNode: TrieNode<K, V>, owner: MutabilityOwnership): TrieNode<K, V> {
205-
assert(newNode.ownedBy === owner)
206-
// assert(buffer[nodeIndex] !== newNode)
207-
208-
// nodes (including collision nodes) that have only one entry are upped if they have no siblings
209-
if (buffer.size == 1 && newNode.buffer.size == ENTRY_SIZE && newNode.nodeMap == 0) {
210-
// assert(dataMap == 0 && nodeMap xor positionMask == 0)
211-
newNode.dataMap = nodeMap
212-
return newNode
213-
}
214-
215-
if (ownedBy === owner) {
198+
if (owner != null && ownedBy === owner) {
216199
buffer[nodeIndex] = newNode
217200
return this
218201
}
202+
219203
val newBuffer = buffer.copyOf()
220204
newBuffer[nodeIndex] = newNode
221205
return TrieNode(dataMap, nodeMap, newBuffer, owner)
@@ -716,7 +700,7 @@ internal class TrieNode<K, V>(
716700
if (targetNode === newNode) {
717701
return this
718702
}
719-
return mutableUpdateNodeAtIndex(nodeIndex, newNode, mutator.ownership)
703+
return updateNodeAtIndex(nodeIndex, keyPositionMask, newNode, mutator.ownership)
720704
}
721705

722706
// key is absent
@@ -791,7 +775,7 @@ internal class TrieNode<K, V>(
791775
newNode == null ->
792776
mutableRemoveNodeAtIndex(nodeIndex, positionMask, owner)
793777
targetNode !== newNode ->
794-
mutableUpdateNodeAtIndex(nodeIndex, newNode, owner)
778+
updateNodeAtIndex(nodeIndex, positionMask, newNode, owner)
795779
else -> this
796780
}
797781

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright 2016-2025 JetBrains s.r.o.
3+
* Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file.
4+
*/
5+
6+
package tests.contract.map
7+
8+
import kotlinx.collections.immutable.implementations.immutableMap.PersistentHashMap
9+
import kotlinx.collections.immutable.persistentHashMapOf
10+
import tests.stress.IntWrapper
11+
import kotlin.collections.iterator
12+
import kotlin.test.Test
13+
import kotlin.test.assertEquals
14+
import kotlin.test.assertFailsWith
15+
import kotlin.test.assertFalse
16+
import kotlin.test.assertTrue
17+
18+
class PersistentHashMapBuilderTest {
19+
20+
@Test
21+
fun `should correctly iterate after removing integer key and promotion colliding key during iteration`() {
22+
val removedKey = 0
23+
val map: PersistentHashMap<Int, String> =
24+
persistentHashMapOf(1 to "a", 2 to "b", 3 to "c", removedKey to "y", 32 to "z")
25+
as PersistentHashMap<Int, String>
26+
27+
validatePromotion(map, removedKey)
28+
}
29+
30+
@Test
31+
fun `should correctly iterate after removing IntWrapper key and promotion colliding key during iteration`() {
32+
val removedKey = IntWrapper(0, 0)
33+
val map: PersistentHashMap<IntWrapper, String> = persistentHashMapOf(
34+
removedKey to "a",
35+
IntWrapper(1, 0) to "b",
36+
IntWrapper(2, 32) to "c",
37+
IntWrapper(3, 32) to "d"
38+
) as PersistentHashMap<IntWrapper, String>
39+
40+
validatePromotion(map, removedKey)
41+
}
42+
43+
private fun <K> validatePromotion(map: PersistentHashMap<K, *>, removedKey: K) {
44+
val builder = map.builder()
45+
val iterator = builder.entries.iterator()
46+
47+
val expectedCount = map.size
48+
var actualCount = 0
49+
50+
while (iterator.hasNext()) {
51+
val (key, _) = iterator.next()
52+
if (key == removedKey) {
53+
iterator.remove()
54+
}
55+
actualCount++
56+
}
57+
58+
val resultMap = builder.build()
59+
for ((key, value) in map) {
60+
if (key != removedKey) {
61+
assertTrue(key in resultMap)
62+
assertEquals(resultMap[key], value)
63+
} else {
64+
assertFalse(key in resultMap)
65+
}
66+
}
67+
68+
assertEquals(expectedCount, actualCount)
69+
}
70+
71+
@Test
72+
fun `removing twice on iterators throws IllegalStateException`() {
73+
val map: PersistentHashMap<Int, String> =
74+
persistentHashMapOf(1 to "a", 2 to "b", 3 to "c", 0 to "y", 32 to "z") as PersistentHashMap<Int, String>
75+
val builder = map.builder()
76+
val iterator = builder.entries.iterator()
77+
78+
assertFailsWith<IllegalStateException> {
79+
while (iterator.hasNext()) {
80+
val (key, _) = iterator.next()
81+
if (key == 0) iterator.remove()
82+
if (key == 0) {
83+
iterator.remove()
84+
iterator.remove()
85+
}
86+
}
87+
}
88+
}
89+
90+
@Test
91+
fun `removing elements from different iterators throws ConcurrentModificationException`() {
92+
val map: PersistentHashMap<Int, String> =
93+
persistentHashMapOf(1 to "a", 2 to "b", 3 to "c", 0 to "y", 32 to "z") as PersistentHashMap<Int, String>
94+
val builder = map.builder()
95+
val iterator1 = builder.entries.iterator()
96+
val iterator2 = builder.entries.iterator()
97+
98+
assertFailsWith<ConcurrentModificationException> {
99+
while (iterator1.hasNext()) {
100+
val (key, _) = iterator1.next()
101+
iterator2.next()
102+
if (key == 0) iterator1.remove()
103+
if (key == 2) iterator2.remove()
104+
}
105+
}
106+
}
107+
108+
@Test
109+
fun `removing element from one iterator and accessing another throws ConcurrentModificationException`() {
110+
val map = persistentHashMapOf(1 to "a", 2 to "b", 3 to "c")
111+
val builder = map.builder()
112+
val iterator1 = builder.entries.iterator()
113+
val iterator2 = builder.entries.iterator()
114+
115+
assertFailsWith<ConcurrentModificationException> {
116+
iterator1.next()
117+
iterator1.remove()
118+
iterator2.next()
119+
}
120+
}
121+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright 2016-2025 JetBrains s.r.o.
3+
* Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file.
4+
*/
5+
6+
package tests.contract.map
7+
8+
import kotlinx.collections.immutable.implementations.immutableMap.PersistentHashMap
9+
import kotlinx.collections.immutable.persistentHashMapOf
10+
import kotlin.test.Test
11+
import kotlin.test.assertEquals
12+
import kotlin.test.assertTrue
13+
14+
class PersistentHashMapTest {
15+
16+
@Test
17+
fun `if the collision is of size 2 and one of the keys is removed the remaining key must be promoted`() {
18+
val map1: PersistentHashMap<Int, String> =
19+
persistentHashMapOf(-1 to "a", 0 to "b", 32 to "c") as PersistentHashMap<Int, String>
20+
val builder = map1.builder()
21+
val map2 = builder.build()
22+
23+
assertTrue(map1.equals(builder))
24+
assertEquals(map1, map2.toMap())
25+
assertEquals(map1, map2)
26+
27+
val map3 = map1.remove(0)
28+
builder.remove(0)
29+
val map4 = builder.build()
30+
31+
assertTrue(map3.equals(builder))
32+
assertEquals(map3, map4.toMap())
33+
assertEquals(map3, map4)
34+
}
35+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright 2016-2025 JetBrains s.r.o.
3+
* Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file.
4+
*/
5+
6+
package tests.contract.set
7+
8+
import kotlinx.collections.immutable.implementations.immutableSet.PersistentHashSet
9+
import kotlinx.collections.immutable.persistentHashSetOf
10+
import tests.stress.IntWrapper
11+
import kotlin.test.Test
12+
import kotlin.test.assertEquals
13+
import kotlin.test.assertFailsWith
14+
import kotlin.test.assertFalse
15+
import kotlin.test.assertTrue
16+
17+
class PersistentHashSetBuilderTest {
18+
19+
@Test
20+
fun `should correctly iterate after removing integer element`() {
21+
val removedElement = 0
22+
val set: PersistentHashSet<Int> =
23+
persistentHashSetOf(1, 2, 3, removedElement, 32)
24+
as PersistentHashSet<Int>
25+
26+
validate(set, removedElement)
27+
}
28+
29+
@Test
30+
fun `should correctly iterate after removing IntWrapper element`() {
31+
val removedElement = IntWrapper(0, 0)
32+
val set: PersistentHashSet<IntWrapper> = persistentHashSetOf(
33+
removedElement,
34+
IntWrapper(1, 0),
35+
IntWrapper(2, 32),
36+
IntWrapper(3, 32)
37+
) as PersistentHashSet<IntWrapper>
38+
39+
validate(set, removedElement)
40+
}
41+
42+
private fun <E> validate(set: PersistentHashSet<E>, removedElement: E) {
43+
val builder = set.builder()
44+
val iterator = builder.iterator()
45+
46+
val expectedCount = set.size
47+
var actualCount = 0
48+
49+
while (iterator.hasNext()) {
50+
val element = iterator.next()
51+
if (element == removedElement) {
52+
iterator.remove()
53+
}
54+
actualCount++
55+
}
56+
57+
val resultSet = builder.build()
58+
for (element in set) {
59+
if (element != removedElement) {
60+
assertTrue(element in resultSet)
61+
} else {
62+
assertFalse(element in resultSet)
63+
}
64+
}
65+
66+
assertEquals(expectedCount, actualCount)
67+
}
68+
69+
@Test
70+
fun `removing twice on iterators throws IllegalStateException`() {
71+
val set: PersistentHashSet<Int> =
72+
persistentHashSetOf(1, 2, 3, 0, 32) as PersistentHashSet<Int>
73+
val builder = set.builder()
74+
val iterator = builder.iterator()
75+
76+
assertFailsWith<IllegalStateException> {
77+
while (iterator.hasNext()) {
78+
val element = iterator.next()
79+
if (element == 0) iterator.remove()
80+
if (element == 0) {
81+
iterator.remove()
82+
iterator.remove()
83+
}
84+
}
85+
}
86+
}
87+
88+
@Test
89+
fun `removing elements from different iterators throws ConcurrentModificationException`() {
90+
val set: PersistentHashSet<Int> =
91+
persistentHashSetOf(1, 2, 3, 0, 32) as PersistentHashSet<Int>
92+
val builder = set.builder()
93+
val iterator1 = builder.iterator()
94+
val iterator2 = builder.iterator()
95+
96+
assertFailsWith<ConcurrentModificationException> {
97+
while (iterator1.hasNext()) {
98+
val element1 = iterator1.next()
99+
iterator2.next()
100+
if (element1 == 0) iterator1.remove()
101+
if (element1 == 2) iterator2.remove()
102+
}
103+
}
104+
}
105+
106+
@Test
107+
fun `removing element from one iterator and accessing another throws ConcurrentModificationException`() {
108+
val set = persistentHashSetOf(1, 2, 3)
109+
val builder = set.builder()
110+
val iterator1 = builder.iterator()
111+
val iterator2 = builder.iterator()
112+
113+
assertFailsWith<ConcurrentModificationException> {
114+
iterator1.next()
115+
iterator1.remove()
116+
iterator2.next()
117+
}
118+
}
119+
}

0 commit comments

Comments
 (0)