1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.grouplens.grapht.util;
21
import org.apache.commons.lang3.ClassUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import javax.annotation.concurrent.Immutable;
import javax.inject.Inject;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.lang.ref.WeakReference;
import java.lang.reflect.*;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;
40
41
42
43
44
45
46
47
48
49
50
51
52 @Immutable
53 public final class ClassProxy implements Serializable {
54 private static final long serialVersionUID = 1;
55 private static final Logger logger = LoggerFactory.getLogger(ClassProxy.class);
56
57 private final String className;
58 private final long checksum;
59 @Nullable
60 private transient volatile WeakReference<Class<?>> theClass;
61 private transient ClassLoader classLoader;
62
63 private ClassProxy(String name, long check) {
64 className = name;
65 checksum = check;
66 classLoader = ClassLoaders.inferDefault(ClassProxy.class);
67 }
68
69 private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
70 stream.defaultReadObject();
71 classLoader = ClassLoaders.inferDefault(ClassProxy.class);
72 }
73
74
75
76
77
78 public String getClassName() {
79 return className;
80 }
81
82 @Override
83 public String toString() {
84 return "proxy of " + className;
85 }
86
87 @Override
88 public boolean equals(Object o) {
89 if (o == this) {
90 return true;
91 } else if (o instanceof ClassProxy) {
92 ClassProxy op = (ClassProxy) o;
93 return className.equals(op.className);
94 } else {
95 return false;
96 }
97 }
98
99 @Override
100 public int hashCode() {
101 return className.hashCode();
102 }
103
104
105
106
107
108
109 public Class<?> resolve() throws ClassNotFoundException {
110 WeakReference<Class<?>> ref = theClass;
111 Class<?> cls = ref == null ? null : ref.get();
112 if (cls == null) {
113 if(className.equals("void")) {
114
115 cls = Void.TYPE;
116 } else {
117 cls = ClassUtils.getClass(classLoader, className);
118 }
119 long check = checksumClass(cls);
120 if (!isSerializationPermissive() && checksum != check) {
121 throw new ClassNotFoundException("checksum mismatch for " + cls.getName());
122 } else {
123 if (checksum != check) {
124 logger.warn("checksum mismatch for {}", cls);
125 }
126 theClass = new WeakReference<Class<?>>(cls);
127 }
128 }
129 return cls;
130 }
131
132 private static final Map<Class<?>, ClassProxy> proxyCache = new WeakHashMap<Class<?>, ClassProxy>();
133
134
135
136
137
138
139
140 public static synchronized ClassProxy of(Class<?> cls) {
141 ClassProxy proxy = proxyCache.get(cls);
142 if (proxy == null) {
143 proxy = new ClassProxy(cls.getName(), checksumClass(cls));
144 proxy.theClass = new WeakReference<Class<?>>(cls);
145 proxyCache.put(cls, proxy);
146 }
147 return proxy;
148 }
149
150 private static final Charset UTF8 = Charset.forName("UTF-8");
151
152 public static boolean isSerializationPermissive() {
153 return Boolean.getBoolean("grapht.deserialization.permissive");
154 }
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169 private static long checksumClass(Class<?> type) {
170 MessageDigest digest;
171 try {
172 digest = MessageDigest.getInstance("MD5");
173 } catch (NoSuchAlgorithmException e) {
174 throw new RuntimeException("JVM does not support MD5", e);
175 }
176 checksumClass(type, digest);
177
178 ByteBuffer buf = ByteBuffer.wrap(digest.digest());
179 long l1 = buf.getLong();
180 long l2 = buf.getLong();
181 return l1 ^ l2;
182 }
183
184 private static void checksumClass(Class<?> type, MessageDigest digest) {
185
186
187 List<String> members = new ArrayList<String>();
188 for (Constructor<?> c: type.getDeclaredConstructors()) {
189 if (isInjectionSensitive(c)) {
190 members.add(String.format("%s(%s)", c.getName(),
191 StringUtils.join(c.getParameterTypes(), ", ")));
192 }
193 }
194 for (Method m: type.getDeclaredMethods()) {
195 if (isInjectionSensitive(m)) {
196 members.add(String.format("%s(%s): %s", m.getName(),
197 StringUtils.join(m.getParameterTypes(), ", "),
198 m.getReturnType()));
199 }
200 }
201 for (Field f: type.getDeclaredFields()) {
202 if (isInjectionSensitive(f)) {
203 members.add(f.getName() + ":" + f.getType().getName());
204 }
205 }
206
207 Collections.sort(members);
208
209 Class<?> sup = type.getSuperclass();
210 if (sup != null) {
211 checksumClass(sup, digest);
212 }
213 for (String mem: members) {
214 digest.update(mem.getBytes(UTF8));
215 }
216 }
217
218
219
220
221
222
223
224
225
226 private static <M extends Member & AnnotatedElement>boolean isInjectionSensitive(M m) {
227
228 if (Modifier.isStatic(m.getModifiers())) {
229 return false;
230 }
231
232
233 if (Modifier.isPrivate(m.getModifiers()) && m.getAnnotation(Inject.class) == null) {
234 return false;
235 }
236
237
238 return true;
239 }
240 }