|
1 | 1 | /* |
2 | | - * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. |
| 2 | + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. |
3 | 3 | */ |
4 | 4 |
|
5 | 5 | package kotlinx.coroutines |
6 | 6 |
|
7 | 7 | import java.lang.reflect.* |
8 | 8 | import java.util.* |
9 | 9 | import java.util.Collections.* |
| 10 | +import java.util.concurrent.atomic.* |
10 | 11 | import kotlin.collections.ArrayList |
| 12 | +import kotlin.test.* |
11 | 13 |
|
12 | 14 | object FieldWalker { |
| 15 | + sealed class Ref { |
| 16 | + object RootRef : Ref() |
| 17 | + class FieldRef(val parent: Any, val name: String) : Ref() |
| 18 | + class ArrayRef(val parent: Any, val index: Int) : Ref() |
| 19 | + } |
| 20 | + |
| 21 | + private val fieldsCache = HashMap<Class<*>, List<Field>>() |
| 22 | + |
| 23 | + init { |
| 24 | + // excluded/terminal classes (don't walk them) |
| 25 | + fieldsCache += listOf(Any::class, String::class, Thread::class, Throwable::class) |
| 26 | + .map { it.java } |
| 27 | + .associateWith { emptyList<Field>() } |
| 28 | + } |
13 | 29 |
|
14 | 30 | /* |
15 | 31 | * Reflectively starts to walk through object graph and returns identity set of all reachable objects. |
| 32 | + * Use [walkRefs] if you need a path from root for debugging. |
| 33 | + */ |
| 34 | + public fun walk(root: Any?): Set<Any> = walkRefs(root).keys |
| 35 | + |
| 36 | + public fun assertReachableCount(expected: Int, root: Any?, predicate: (Any) -> Boolean) { |
| 37 | + val visited = walkRefs(root) |
| 38 | + val actual = visited.keys.filter(predicate) |
| 39 | + if (actual.size != expected) { |
| 40 | + val textDump = actual.joinToString("") { "\n\t" + showPath(it, visited) } |
| 41 | + assertEquals( |
| 42 | + expected, actual.size, |
| 43 | + "Unexpected number objects. Expected $expected, found ${actual.size}$textDump" |
| 44 | + ) |
| 45 | + } |
| 46 | + } |
| 47 | + |
| 48 | + /* |
| 49 | + * Reflectively starts to walk through object graph and map to all the reached object to their path |
| 50 | + * in from root. Use [showPath] do display a path if needed. |
16 | 51 | */ |
17 | | - public fun walk(root: Any): Set<Any> { |
18 | | - val result = newSetFromMap<Any>(IdentityHashMap()) |
19 | | - result.add(root) |
| 52 | + private fun walkRefs(root: Any?): Map<Any, Ref> { |
| 53 | + val visited = IdentityHashMap<Any, Ref>() |
| 54 | + if (root == null) return visited |
| 55 | + visited[root] = Ref.RootRef |
20 | 56 | val stack = ArrayDeque<Any>() |
21 | 57 | stack.addLast(root) |
22 | 58 | while (stack.isNotEmpty()) { |
23 | 59 | val element = stack.removeLast() |
24 | | - val type = element.javaClass |
25 | | - type.visit(element, result, stack) |
| 60 | + try { |
| 61 | + visit(element, visited, stack) |
| 62 | + } catch (e: Exception) { |
| 63 | + error("Failed to visit element ${showPath(element, visited)}: $e") |
| 64 | + } |
26 | 65 | } |
27 | | - return result |
| 66 | + return visited |
28 | 67 | } |
29 | 68 |
|
30 | | - private fun Class<*>.visit( |
31 | | - element: Any, |
32 | | - result: MutableSet<Any>, |
33 | | - stack: ArrayDeque<Any> |
34 | | - ) { |
35 | | - val fields = fields() |
36 | | - fields.forEach { |
37 | | - it.isAccessible = true |
38 | | - val value = it.get(element) ?: return@forEach |
39 | | - if (result.add(value)) { |
40 | | - stack.addLast(value) |
| 69 | + private fun showPath(element: Any, visited: Map<Any, Ref>): String { |
| 70 | + val path = ArrayList<String>() |
| 71 | + var cur = element |
| 72 | + while (true) { |
| 73 | + val ref = visited.getValue(cur) |
| 74 | + if (ref is Ref.RootRef) break |
| 75 | + when (ref) { |
| 76 | + is Ref.FieldRef -> { |
| 77 | + cur = ref.parent |
| 78 | + path += ".${ref.name}" |
| 79 | + } |
| 80 | + is Ref.ArrayRef -> { |
| 81 | + cur = ref.parent |
| 82 | + path += "[${ref.index}]" |
| 83 | + } |
41 | 84 | } |
42 | 85 | } |
| 86 | + path.reverse() |
| 87 | + return path.joinToString("") |
| 88 | + } |
43 | 89 |
|
44 | | - if (isArray && !componentType.isPrimitive) { |
45 | | - val array = element as Array<Any?> |
46 | | - array.filterNotNull().forEach { |
47 | | - if (result.add(it)) { |
48 | | - stack.addLast(it) |
| 90 | + private fun visit(element: Any, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>) { |
| 91 | + val type = element.javaClass |
| 92 | + when { |
| 93 | + // Special code for arrays |
| 94 | + type.isArray && !type.componentType.isPrimitive -> { |
| 95 | + @Suppress("UNCHECKED_CAST") |
| 96 | + val array = element as Array<Any?> |
| 97 | + array.forEachIndexed { index, value -> |
| 98 | + push(value, visited, stack) { Ref.ArrayRef(element, index) } |
| 99 | + } |
| 100 | + } |
| 101 | + // Special code for platform types that cannot be reflectively accessed on modern JDKs |
| 102 | + type.name.startsWith("java.") && element is Collection<*> -> { |
| 103 | + element.forEachIndexed { index, value -> |
| 104 | + push(value, visited, stack) { Ref.ArrayRef(element, index) } |
| 105 | + } |
| 106 | + } |
| 107 | + type.name.startsWith("java.") && element is Map<*, *> -> { |
| 108 | + push(element.keys, visited, stack) { Ref.FieldRef(element, "keys") } |
| 109 | + push(element.values, visited, stack) { Ref.FieldRef(element, "values") } |
| 110 | + } |
| 111 | + element is AtomicReference<*> -> { |
| 112 | + push(element.get(), visited, stack) { Ref.FieldRef(element, "value") } |
| 113 | + } |
| 114 | + // All the other classes are reflectively scanned |
| 115 | + else -> fields(type).forEach { field -> |
| 116 | + push(field.get(element), visited, stack) { Ref.FieldRef(element, field.name) } |
| 117 | + // special case to scan Throwable cause (cannot get it reflectively) |
| 118 | + if (element is Throwable) { |
| 119 | + push(element.cause, visited, stack) { Ref.FieldRef(element, "cause") } |
49 | 120 | } |
50 | 121 | } |
51 | 122 | } |
52 | 123 | } |
53 | 124 |
|
54 | | - private fun Class<*>.fields(): List<Field> { |
| 125 | + private inline fun push(value: Any?, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, ref: () -> Ref) { |
| 126 | + if (value != null && !visited.containsKey(value)) { |
| 127 | + visited[value] = ref() |
| 128 | + stack.addLast(value) |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + private fun fields(type0: Class<*>): List<Field> { |
| 133 | + fieldsCache[type0]?.let { return it } |
55 | 134 | val result = ArrayList<Field>() |
56 | | - var type = this |
57 | | - while (type != Any::class.java) { |
| 135 | + var type = type0 |
| 136 | + while (true) { |
58 | 137 | val fields = type.declaredFields.filter { |
59 | 138 | !it.type.isPrimitive |
60 | 139 | && !Modifier.isStatic(it.modifiers) |
61 | 140 | && !(it.type.isArray && it.type.componentType.isPrimitive) |
62 | 141 | } |
| 142 | + fields.forEach { it.isAccessible = true } // make them all accessible |
63 | 143 | result.addAll(fields) |
64 | 144 | type = type.superclass |
65 | | - } |
66 | | - |
67 | | - return result |
68 | | - } |
69 | | - |
70 | | - // Debugging-only |
71 | | - @Suppress("UNUSED") |
72 | | - fun printPath(from: Any, to: Any) { |
73 | | - val pathNodes = ArrayList<String>() |
74 | | - val visited = newSetFromMap<Any>(IdentityHashMap()) |
75 | | - visited.add(from) |
76 | | - if (findPath(from, to, visited, pathNodes)) { |
77 | | - pathNodes.reverse() |
78 | | - println(pathNodes.joinToString(" -> ", from.javaClass.simpleName + " -> ", "-> " + to.javaClass.simpleName)) |
79 | | - } else { |
80 | | - println("Path from $from to $to not found") |
81 | | - } |
82 | | - } |
83 | | - |
84 | | - private fun findPath(from: Any, to: Any, visited: MutableSet<Any>, pathNodes: MutableList<String>): Boolean { |
85 | | - if (from === to) { |
86 | | - return true |
87 | | - } |
88 | | - |
89 | | - val type = from.javaClass |
90 | | - if (type.isArray) { |
91 | | - if (type.componentType.isPrimitive) return false |
92 | | - val array = from as Array<Any?> |
93 | | - array.filterNotNull().forEach { |
94 | | - if (findPath(it, to, visited, pathNodes)) { |
95 | | - return true |
96 | | - } |
| 145 | + val superFields = fieldsCache[type] // will stop at Any anyway |
| 146 | + if (superFields != null) { |
| 147 | + result.addAll(superFields) |
| 148 | + break |
97 | 149 | } |
98 | | - return false |
99 | 150 | } |
100 | | - |
101 | | - val fields = type.fields() |
102 | | - fields.forEach { |
103 | | - it.isAccessible = true |
104 | | - val value = it.get(from) ?: return@forEach |
105 | | - if (!visited.add(value)) return@forEach |
106 | | - val found = findPath(value, to, visited, pathNodes) |
107 | | - if (found) { |
108 | | - pathNodes += from.javaClass.simpleName + ":" + it.name |
109 | | - return true |
110 | | - } |
111 | | - } |
112 | | - |
113 | | - return false |
| 151 | + fieldsCache[type0] = result |
| 152 | + return result |
114 | 153 | } |
115 | 154 | } |
0 commit comments