diff --git a/src/main/java/org/jboss/netty/handler/codec/serialization/CompactObjectInputStream.java b/src/main/java/org/jboss/netty/handler/codec/serialization/CompactObjectInputStream.java index 1d99bff7f3..c3f1afe557 100644 --- a/src/main/java/org/jboss/netty/handler/codec/serialization/CompactObjectInputStream.java +++ b/src/main/java/org/jboss/netty/handler/codec/serialization/CompactObjectInputStream.java @@ -21,6 +21,8 @@ import java.io.InputStream; import java.io.ObjectInputStream; import java.io.ObjectStreamClass; import java.io.StreamCorruptedException; +import java.util.HashMap; +import java.util.Map; /** * @author The Netty Project @@ -31,6 +33,7 @@ import java.io.StreamCorruptedException; */ class CompactObjectInputStream extends ObjectInputStream { + private final Map> classCache = new HashMap>(); private final ClassLoader classLoader; CompactObjectInputStream(InputStream in) throws IOException { @@ -65,7 +68,7 @@ class CompactObjectInputStream extends ObjectInputStream { case CompactObjectOutputStream.TYPE_THIN_DESCRIPTOR: String className = readUTF(); Class clazz = loadClass(className); - return ObjectStreamClass.lookup(clazz); + return ObjectStreamClass.lookupAny(clazz); default: throw new StreamCorruptedException( "Unexpected class descriptor type: " + type); @@ -74,18 +77,33 @@ class CompactObjectInputStream extends ObjectInputStream { @Override protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + // Query the cache first. String className = desc.getName(); - try { - return loadClass(className); - } catch (ClassNotFoundException ex) { - // FIXME Cache the result to avoid the cost of ClassNotFoundException. - return super.resolveClass(desc); + Class clazz = classCache.get(className); + if (clazz != null) { + return clazz; } + + // And then try to resolve. + try { + clazz = loadClass(className); + } catch (ClassNotFoundException ex) { + clazz = super.resolveClass(desc); + } + + classCache.put(className, clazz); + return clazz; } protected Class loadClass(String className) throws ClassNotFoundException { - // FIXME Cache the result to avoid the cost of loading classes. + // Query the cache first. Class clazz; + clazz = classCache.get(className); + if (clazz != null) { + return clazz; + } + + // And then try to load. ClassLoader classLoader = this.classLoader; if (classLoader == null) { classLoader = Thread.currentThread().getContextClassLoader(); @@ -96,6 +114,8 @@ class CompactObjectInputStream extends ObjectInputStream { } else { clazz = Class.forName(className); } + + classCache.put(className, clazz); return clazz; } } diff --git a/src/test/java/org/jboss/netty/handler/codec/serialization/CompactObjectSerializationTest.java b/src/test/java/org/jboss/netty/handler/codec/serialization/CompactObjectSerializationTest.java new file mode 100644 index 0000000000..0db5000c7c --- /dev/null +++ b/src/test/java/org/jboss/netty/handler/codec/serialization/CompactObjectSerializationTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2011 Red Hat, Inc. + * + * Red Hat licenses this file to you under the Apache License, version 2.0 + * (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package org.jboss.netty.handler.codec.serialization; + +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.util.List; + +import org.junit.Assert; +import org.junit.Test; + +/** + * @author Trustin Lee + */ +public class CompactObjectSerializationTest { + + @Test + public void testInterfaceSerialization() throws Exception { + PipedOutputStream pipeOut = new PipedOutputStream(); + PipedInputStream pipeIn = new PipedInputStream(pipeOut); + CompactObjectOutputStream out = new CompactObjectOutputStream(pipeOut); + CompactObjectInputStream in = new CompactObjectInputStream(pipeIn); + out.writeObject(List.class); + Assert.assertSame(List.class, in.readObject()); + } +}