001 /*
002 * $HeadURL: http://juliusdavies.ca/svn/not-yet-commons-ssl/tags/commons-ssl-0.3.11/src/java/org/apache/commons/ssl/RMISocketFactoryImpl.java $
003 * $Revision: 144 $
004 * $Date: 2009-05-25 11:14:29 -0700 (Mon, 25 May 2009) $
005 *
006 * ====================================================================
007 * Licensed to the Apache Software Foundation (ASF) under one
008 * or more contributor license agreements. See the NOTICE file
009 * distributed with this work for additional information
010 * regarding copyright ownership. The ASF licenses this file
011 * to you under the Apache License, Version 2.0 (the
012 * "License"); you may not use this file except in compliance
013 * with the License. You may obtain a copy of the License at
014 *
015 * http://www.apache.org/licenses/LICENSE-2.0
016 *
017 * Unless required by applicable law or agreed to in writing,
018 * software distributed under the License is distributed on an
019 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
020 * KIND, either express or implied. See the License for the
021 * specific language governing permissions and limitations
022 * under the License.
023 * ====================================================================
024 *
025 * This software consists of voluntary contributions made by many
026 * individuals on behalf of the Apache Software Foundation. For more
027 * information on the Apache Software Foundation, please see
028 * <http://www.apache.org/>.
029 *
030 */
031
032 package org.apache.commons.ssl;
033
034 import javax.net.ServerSocketFactory;
035 import javax.net.SocketFactory;
036 import javax.net.ssl.SSLException;
037 import javax.net.ssl.SSLPeerUnverifiedException;
038 import javax.net.ssl.SSLProtocolException;
039 import javax.net.ssl.SSLSocket;
040 import java.io.EOFException;
041 import java.io.IOException;
042 import java.io.InterruptedIOException;
043 import java.net.DatagramSocket;
044 import java.net.InetAddress;
045 import java.net.NetworkInterface;
046 import java.net.ServerSocket;
047 import java.net.Socket;
048 import java.net.SocketException;
049 import java.net.UnknownHostException;
050 import java.rmi.server.RMISocketFactory;
051 import java.security.GeneralSecurityException;
052 import java.security.cert.X509Certificate;
053 import java.util.Arrays;
054 import java.util.Collections;
055 import java.util.Enumeration;
056 import java.util.HashMap;
057 import java.util.Iterator;
058 import java.util.LinkedList;
059 import java.util.Map;
060 import java.util.Set;
061 import java.util.SortedSet;
062 import java.util.TreeMap;
063 import java.util.TreeSet;
064
065
066 /**
067 * An RMISocketFactory ideal for using RMI over SSL. The server secures both
068 * the registry and the remote objects. The client assumes that either both
069 * the registry and the remote objects will use SSL, or both will use
070 * plain-socket. The client is able to auto detect plain-socket registries
071 * and downgrades itself to accomodate those.
072 * <p/>
073 * Unlike most existing RMI over SSL solutions in use (including Java 5's
074 * javax.rmi.ssl.SslRMIClientSocketFactory), this one does proper SSL hostname
075 * verification. From the client perspective this is straighforward. From
076 * the server perspective we introduce a clever trick: we perform an initial
077 * "hostname verification" by trying the current value of
078 * "java.rmi.server.hostname" against our server certificate. If the
079 * "java.rmi.server.hostname" System Property isn't set, we set it ourselves
080 * using the CN value we extract from our server certificate! (Some
081 * complications arise should a wildcard certificate show up, but we try our
082 * best to deal with those).
083 * <p/>
084 * An SSL server cannot be started without a private key. We have defined some
085 * default behaviour for trying to find a private key to use that we believe
086 * is convenient and sensible:
087 * <p/>
088 * If running from inside Tomcat, we try to re-use Tomcat's private key and
089 * certificate chain (assuming Tomcat-SSL on port 8443 is enabled). If this
090 * isn't available, we look for the "javax.net.ssl.keyStore" System property.
091 * Finally, if that isn't available, we look for "~/.keystore" and assume
092 * a password of "changeit".
093 * <p/>
094 * If after all these attempts we still failed to find a private key, the
095 * RMISocketFactoryImpl() constructor will throw an SSLException.
096 *
097 * @author Credit Union Central of British Columbia
098 * @author <a href="http://www.cucbc.com/">www.cucbc.com</a>
099 * @author <a href="mailto:juliusdavies@cucbc.com">juliusdavies@cucbc.com</a>
100 * @since 22-Apr-2005
101 */
102 public class RMISocketFactoryImpl extends RMISocketFactory {
103 public final static String RMI_HOSTNAME_KEY = "java.rmi.server.hostname";
104 private final static LogWrapper log = LogWrapper.getLogger(RMISocketFactoryImpl.class);
105
106 private volatile SocketFactory defaultClient;
107 private volatile ServerSocketFactory sslServer;
108 private volatile String localBindAddress = null;
109 private volatile int anonymousPort = 31099;
110 private Map clientMap = new TreeMap();
111 private Map serverSockets = new HashMap();
112 private final SocketFactory plainClient = SocketFactory.getDefault();
113
114 public RMISocketFactoryImpl() throws GeneralSecurityException, IOException {
115 this(true);
116 }
117
118 /**
119 * @param createDefaultServer If false, then we only set the default
120 * client, and the default server is set to null.
121 * If true, then a default server is also created.
122 * @throws GeneralSecurityException bad things
123 * @throws IOException bad things
124 */
125 public RMISocketFactoryImpl(boolean createDefaultServer)
126 throws GeneralSecurityException, IOException {
127 SSLServer defaultServer = createDefaultServer ? new SSLServer() : null;
128 SSLClient defaultClient = new SSLClient();
129
130 // RMI calls to localhost will not check that host matches CN in
131 // certificate. Hopefully this is acceptable. (The registry server
132 // will followup the registry lookup with the proper DNS name to get
133 // the remote object, anyway).
134 HostnameVerifier verifier = HostnameVerifier.DEFAULT_AND_LOCALHOST;
135 defaultClient.setHostnameVerifier(verifier);
136 if (defaultServer != null) {
137 defaultServer.setHostnameVerifier(verifier);
138 // The RMI server will try to re-use Tomcat's "port 8443" SSL
139 // Certificate if possible.
140 defaultServer.useTomcatSSLMaterial();
141 X509Certificate[] x509 = defaultServer.getAssociatedCertificateChain();
142 if (x509 == null || x509.length < 1) {
143 throw new SSLException("Cannot initialize RMI-SSL Server: no KeyMaterial!");
144 }
145 setServer(defaultServer);
146 }
147 setDefaultClient(defaultClient);
148 }
149
150 public void setServer(ServerSocketFactory f)
151 throws GeneralSecurityException, IOException {
152 this.sslServer = f;
153 if (f instanceof SSLServer) {
154 final HostnameVerifier VERIFIER;
155 VERIFIER = HostnameVerifier.DEFAULT_AND_LOCALHOST;
156
157 final SSLServer ssl = (SSLServer) f;
158 final X509Certificate[] chain = ssl.getAssociatedCertificateChain();
159 String[] cns = Certificates.getCNs(chain[0]);
160 String[] subjectAlts = Certificates.getDNSSubjectAlts(chain[0]);
161 LinkedList names = new LinkedList();
162 if (cns != null && cns.length > 0) {
163 // Only first CN is used. Not going to get into the IE6 nonsense
164 // where all CN values are used.
165 names.add(cns[0]);
166 }
167 if (subjectAlts != null && subjectAlts.length > 0) {
168 names.addAll(Arrays.asList(subjectAlts));
169 }
170
171 String rmiHostName = System.getProperty(RMI_HOSTNAME_KEY);
172 // If "java.rmi.server.hostname" is already set, don't mess with it.
173 // But blowup if it's not going to work with our SSL Server
174 // Certificate!
175 if (rmiHostName != null) {
176 try {
177 VERIFIER.check(rmiHostName, cns, subjectAlts);
178 }
179 catch (SSLException ssle) {
180 String s = ssle.toString();
181 throw new SSLException(RMI_HOSTNAME_KEY + " of " + rmiHostName + " conflicts with SSL Server Certificate: " + s);
182 }
183 } else {
184 // If SSL Cert only contains one non-wild name, just use that and
185 // hope for the best.
186 boolean hopingForBest = false;
187 if (names.size() == 1) {
188 String name = (String) names.get(0);
189 if (!name.startsWith("*")) {
190 System.setProperty(RMI_HOSTNAME_KEY, name);
191 log.warn("commons-ssl '" + RMI_HOSTNAME_KEY + "' set to '" + name + "' as found in my SSL Server Certificate.");
192 hopingForBest = true;
193 }
194 }
195 if (!hopingForBest) {
196 // Help me, Obi-Wan Kenobi; you're my only hope. All we can
197 // do now is grab our internet-facing addresses, reverse-lookup
198 // on them, and hope that one of them validates against our
199 // server cert.
200 Set s = getMyInternetFacingIPs();
201 Iterator it = s.iterator();
202 while (it.hasNext()) {
203 String name = (String) it.next();
204 try {
205 VERIFIER.check(name, cns, subjectAlts);
206 System.setProperty(RMI_HOSTNAME_KEY, name);
207 log.warn("commons-ssl '" + RMI_HOSTNAME_KEY + "' set to '" + name + "' as found by reverse-dns against my own IP.");
208 hopingForBest = true;
209 break;
210 }
211 catch (SSLException ssle) {
212 // next!
213 }
214 }
215 }
216 if (!hopingForBest) {
217 throw new SSLException("'" + RMI_HOSTNAME_KEY + "' not present. Must work with my SSL Server Certificate's CN field: " + names);
218 }
219 }
220 }
221 trustOurself();
222 }
223
224 public void setLocalBindAddress(String localBindAddress) {
225 this.localBindAddress = localBindAddress;
226 }
227
228 public void setAnonymousPort(int port) {
229 this.anonymousPort = port;
230 }
231
232 public void setDefaultClient(SocketFactory f)
233 throws GeneralSecurityException, IOException {
234 this.defaultClient = f;
235 trustOurself();
236 }
237
238 public void setClient(String host, SocketFactory f)
239 throws GeneralSecurityException, IOException {
240 if (f != null && sslServer != null) {
241 boolean clientIsCommonsSSL = f instanceof SSLClient;
242 boolean serverIsCommonsSSL = sslServer instanceof SSLServer;
243 if (clientIsCommonsSSL && serverIsCommonsSSL) {
244 SSLClient c = (SSLClient) f;
245 SSLServer s = (SSLServer) sslServer;
246 trustEachOther(c, s);
247 }
248 }
249 Set names = hostnamePossibilities(host);
250 Iterator it = names.iterator();
251 synchronized (this) {
252 while (it.hasNext()) {
253 clientMap.put(it.next(), f);
254 }
255 }
256 }
257
258 public void removeClient(String host) {
259 Set names = hostnamePossibilities(host);
260 Iterator it = names.iterator();
261 synchronized (this) {
262 while (it.hasNext()) {
263 clientMap.remove(it.next());
264 }
265 }
266 }
267
268 public synchronized void removeClient(SocketFactory sf) {
269 Iterator it = clientMap.entrySet().iterator();
270 while (it.hasNext()) {
271 Map.Entry entry = (Map.Entry) it.next();
272 Object o = entry.getValue();
273 if (sf.equals(o)) {
274 it.remove();
275 }
276 }
277 }
278
279 private Set hostnamePossibilities(String host) {
280 host = host != null ? host.toLowerCase().trim() : "";
281 if ("".equals(host)) {
282 return Collections.EMPTY_SET;
283 }
284 TreeSet names = new TreeSet();
285 names.add(host);
286 InetAddress[] addresses;
287 try {
288 // If they gave us "hostname.com", this will give us the various
289 // IP addresses:
290 addresses = InetAddress.getAllByName(host);
291 for (int i = 0; i < addresses.length; i++) {
292 String name1 = addresses[i].getHostName();
293 String name2 = addresses[i].getHostAddress();
294 names.add(name1.trim().toLowerCase());
295 names.add(name2.trim().toLowerCase());
296 }
297 }
298 catch (UnknownHostException uhe) {
299 /* oh well, nothing found, nothing to add for this client */
300 }
301
302 try {
303 host = InetAddress.getByName(host).getHostAddress();
304
305 // If they gave us "1.2.3.4", this will hopefully give us
306 // "hostname.com" so that we can then try and find any other
307 // IP addresses associated with that name.
308 host = InetAddress.getByName(host).getHostName();
309 names.add(host.trim().toLowerCase());
310 addresses = InetAddress.getAllByName(host);
311 for (int i = 0; i < addresses.length; i++) {
312 String name1 = addresses[i].getHostName();
313 String name2 = addresses[i].getHostAddress();
314 names.add(name1.trim().toLowerCase());
315 names.add(name2.trim().toLowerCase());
316 }
317 }
318 catch (UnknownHostException uhe) {
319 /* oh well, nothing found, nothing to add for this client */
320 }
321 return names;
322 }
323
324 private void trustOurself()
325 throws GeneralSecurityException, IOException {
326 if (defaultClient == null || sslServer == null) {
327 return;
328 }
329 boolean clientIsCommonsSSL = defaultClient instanceof SSLClient;
330 boolean serverIsCommonsSSL = sslServer instanceof SSLServer;
331 if (clientIsCommonsSSL && serverIsCommonsSSL) {
332 SSLClient c = (SSLClient) defaultClient;
333 SSLServer s = (SSLServer) sslServer;
334 trustEachOther(c, s);
335 }
336 }
337
338 private void trustEachOther(SSLClient client, SSLServer server)
339 throws GeneralSecurityException, IOException {
340 if (client != null && server != null) {
341 // Our own client should trust our own server.
342 X509Certificate[] certs = server.getAssociatedCertificateChain();
343 if (certs != null && certs[0] != null) {
344 TrustMaterial tm = new TrustMaterial(certs[0]);
345 client.addTrustMaterial(tm);
346 }
347
348 // Our own server should trust our own client.
349 certs = client.getAssociatedCertificateChain();
350 if (certs != null && certs[0] != null) {
351 TrustMaterial tm = new TrustMaterial(certs[0]);
352 server.addTrustMaterial(tm);
353 }
354 }
355 }
356
357 public ServerSocketFactory getServer() { return sslServer; }
358
359 public SocketFactory getDefaultClient() { return defaultClient; }
360
361 public synchronized SocketFactory getClient(String host) {
362 host = host != null ? host.trim().toLowerCase() : "";
363 return (SocketFactory) clientMap.get(host);
364 }
365
366 public synchronized ServerSocket createServerSocket(int port)
367 throws IOException {
368 // Re-use existing ServerSocket if possible.
369 if (port == 0) {
370 port = anonymousPort;
371 }
372 Integer key = new Integer(port);
373 ServerSocket ss = (ServerSocket) serverSockets.get(key);
374 if (ss == null || ss.isClosed()) {
375 if (ss != null && ss.isClosed()) {
376 System.out.println("found closed server on port: " + port);
377 }
378 log.debug("commons-ssl RMI server-socket: listening on port " + port);
379 ss = sslServer.createServerSocket(port);
380 serverSockets.put(key, ss);
381 }
382 return ss;
383 }
384
385 public Socket createSocket(String host, int port)
386 throws IOException {
387 host = host != null ? host.trim().toLowerCase() : "";
388 InetAddress local = null;
389 String bindAddress = localBindAddress;
390 if (bindAddress == null) {
391 bindAddress = System.getProperty(RMI_HOSTNAME_KEY);
392 if (bindAddress != null) {
393 local = InetAddress.getByName(bindAddress);
394 if (!local.isLoopbackAddress()) {
395 String ip = local.getHostAddress();
396 Set myInternetIps = getMyInternetFacingIPs();
397 if (!myInternetIps.contains(ip)) {
398 log.warn("Cannot bind to " + ip + " since it doesn't exist on this machine.");
399 // Not going to be able to bind as this. Our RMI_HOSTNAME_KEY
400 // must be set to some kind of proxy in front of us. So we
401 // still want to use it, but we can't bind to it.
402 local = null;
403 bindAddress = null;
404 }
405 }
406 }
407 }
408 if (bindAddress == null) {
409 // Our last resort - let's make sure we at least use something that's
410 // internet facing!
411 bindAddress = getMyDefaultIP();
412 }
413 if (local == null && bindAddress != null) {
414 local = InetAddress.getByName(bindAddress);
415 localBindAddress = local.getHostName();
416 }
417
418 SocketFactory sf;
419 synchronized (this) {
420 sf = (SocketFactory) clientMap.get(host);
421 }
422 if (sf == null) {
423 sf = defaultClient;
424 }
425
426 Socket s = null;
427 SSLSocket ssl = null;
428 int soTimeout = Integer.MIN_VALUE;
429 IOException reasonForPlainSocket = null;
430 boolean tryPlain = false;
431 try {
432 s = sf.createSocket(host, port, local, 0);
433 soTimeout = s.getSoTimeout();
434 if (!(s instanceof SSLSocket)) {
435 // Someone called setClient() or setDefaultClient() and passed in
436 // a plain socket factory. Okay, nothing to see, move along.
437 return s;
438 } else {
439 ssl = (SSLSocket) s;
440 }
441
442 // If we don't get the peer certs in 15 seconds, revert to plain
443 // socket.
444 ssl.setSoTimeout(15000);
445 ssl.getSession().getPeerCertificates();
446
447 // Everything worked out okay, so go back to original soTimeout.
448 ssl.setSoTimeout(soTimeout);
449 return ssl;
450 }
451 catch (IOException ioe) {
452 // SSL didn't work. Let's analyze the IOException to see if maybe
453 // we're accidentally attempting to talk to a plain-socket RMI
454 // server.
455 Throwable t = ioe;
456 while (!tryPlain && t != null) {
457 tryPlain = tryPlain || t instanceof EOFException;
458 tryPlain = tryPlain || t instanceof InterruptedIOException;
459 tryPlain = tryPlain || t instanceof SSLProtocolException;
460 t = t.getCause();
461 }
462 if (!tryPlain && ioe instanceof SSLPeerUnverifiedException) {
463 try {
464 if (ssl != null) {
465 ssl.startHandshake();
466 }
467 }
468 catch (IOException ioe2) {
469 // Stacktrace from startHandshake() will be more descriptive
470 // then the one we got from getPeerCertificates().
471 ioe = ioe2;
472 t = ioe2;
473 while (!tryPlain && t != null) {
474 tryPlain = tryPlain || t instanceof EOFException;
475 tryPlain = tryPlain || t instanceof InterruptedIOException;
476 tryPlain = tryPlain || t instanceof SSLProtocolException;
477 t = t.getCause();
478 }
479 }
480 }
481 if (!tryPlain) {
482 log.debug("commons-ssl RMI-SSL failed: " + ioe);
483 throw ioe;
484 } else {
485 reasonForPlainSocket = ioe;
486 }
487 }
488 finally {
489 // Some debug logging:
490 boolean isPlain = tryPlain || (s != null && ssl == null);
491 String socket = isPlain ? "RMI plain-socket " : "RMI ssl-socket ";
492 String localIP = local != null ? local.getHostAddress() : "ANY";
493 StringBuffer buf = new StringBuffer(64);
494 buf.append(socket);
495 buf.append(localIP);
496 buf.append(" --> ");
497 buf.append(host);
498 buf.append(":");
499 buf.append(port);
500 log.debug(buf.toString());
501 }
502
503 // SSL didn't work. Remote server either timed out, or sent EOF, or
504 // there was some kind of SSLProtocolException. (Any other problem
505 // would have caused an IOException to be thrown, so execution wouldn't
506 // have made it this far). Maybe plain socket will work in these three
507 // cases.
508 sf = plainClient;
509 s = JavaImpl.connect(null, sf, host, port, local, 0, 15000, null);
510 if (soTimeout != Integer.MIN_VALUE) {
511 s.setSoTimeout(soTimeout);
512 }
513
514 try {
515 // Plain socket worked! Let's remember that for next time an RMI call
516 // against this host happens.
517 setClient(host, plainClient);
518 String msg = "RMI downgrading from SSL to plain-socket for " + host + " because of " + reasonForPlainSocket;
519 log.warn(msg, reasonForPlainSocket);
520 }
521 catch (GeneralSecurityException gse) {
522 throw new RuntimeException("can't happen because we're using plain socket", gse);
523 // won't happen because we're using plain socket, not SSL.
524 }
525
526 return s;
527 }
528
529
530 public static String getMyDefaultIP() {
531 String anInternetIP = "64.111.122.211";
532 String ip = null;
533 try {
534 DatagramSocket dg = new DatagramSocket();
535 dg.setSoTimeout(250);
536 // 64.111.122.211 is juliusdavies.ca.
537 // This code doesn't actually send any packets (so no firewalls can
538 // get in the way). It's just a neat trick for getting our
539 // internet-facing interface card.
540 InetAddress addr = InetAddress.getByName(anInternetIP);
541 dg.connect(addr, 12345);
542 InetAddress localAddr = dg.getLocalAddress();
543 ip = localAddr.getHostAddress();
544 // log.debug( "Using bogus UDP socket (" + anInternetIP + ":12345), I think my IP address is: " + ip );
545 dg.close();
546 if (localAddr.isLoopbackAddress() || "0.0.0.0".equals(ip)) {
547 ip = null;
548 }
549 }
550 catch (IOException ioe) {
551 log.debug("Bogus UDP didn't work: " + ioe);
552 }
553 return ip;
554 }
555
556 public static SortedSet getMyInternetFacingIPs() throws SocketException {
557 TreeSet set = new TreeSet();
558 Enumeration en = NetworkInterface.getNetworkInterfaces();
559 while (en.hasMoreElements()) {
560 NetworkInterface ni = (NetworkInterface) en.nextElement();
561 Enumeration en2 = ni.getInetAddresses();
562 while (en2.hasMoreElements()) {
563 InetAddress addr = (InetAddress) en2.nextElement();
564 if (!addr.isLoopbackAddress()) {
565 String ip = addr.getHostAddress();
566 String reverse = addr.getHostName();
567 // IP:
568 set.add(ip);
569 // Reverse-Lookup:
570 set.add(reverse);
571
572 }
573 }
574 }
575 return set;
576 }
577
578 }