package com.algobase.share.network;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import java.net.SocketAddress;
import java.net.InetSocketAddress;
import java.net.Socket;

import java.nio.charset.StandardCharsets;

import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;

import java.security.KeyStore;
import java.security.cert.X509Certificate;


/**
 * Small blocking TLS client built on {@link SSLSocket}.
 *
 * <p>No method throws: failures are reported by a {@code false} / {@code -1}
 * return value, and a text description of the last failure is available from
 * {@link #getError()} ({@code null} after a successful call).
 *
 * <p>NOTE(review): the context built by {@link #init_ssl_context()} accepts any
 * server certificate chain without validation, so traffic is encrypted but the
 * server is not authenticated (man-in-the-middle is possible). Replace the trust
 * manager before using this over untrusted networks.
 *
 * <p>No synchronization is performed; confine an instance to a single thread.
 */
public class SSL_Socket {

  SSLContext sslContext;        // TLS context; stays null if the constructor could not build it

  SSLSocketFactory factory;     // null if the SSL context could not be created
  SSLSocket sock;               // current socket, null while disconnected

  int connect_timeout;          // connect timeout in milliseconds
  int receive_timeout;          // read timeout (SO_TIMEOUT) in milliseconds, 0 = wait forever

  String error_msg;             // description of the last failure, null on success
  String init_error_msg;        // failure raised while building the SSL context, if any
  OutputStream out;             // output stream of the current socket, null while disconnected
  InputStream in;               // input stream of the current socket, null while disconnected

  /**
   * Creates an unconnected wrapper with a 5000 ms connect timeout and no
   * receive timeout. If the SSL context cannot be built, the error is kept
   * and every later {@link #connect} call fails with that message.
   */
  public SSL_Socket() {
    sock = null;
    connect_timeout = 5000;
    receive_timeout = 0;
    try {
      init_ssl_context();
      factory = sslContext.getSocketFactory();
    } catch (Exception ex) {
      // Do not swallow: without a context there is no factory, and calling
      // getSocketFactory() on a missing context would crash the constructor.
      factory = null;
      init_error_msg = ex.toString();
      error_msg = init_error_msg;
    }
  }

  /**
   * Creates the wrapper and immediately tries to connect to {@code host:port}.
   * Check {@link #connected()} / {@link #getError()} for the outcome.
   */
  public SSL_Socket(String host, int port) {
    this();
    connect(host, port);
  }

  /**
   * Builds the TLS context used for new sockets and stores it in
   * {@code sslContext} once it is fully initialised.
   *
   * @throws Exception if the TLS provider or key manager cannot be set up
   */
  public void init_ssl_context() throws Exception {
    SSLContext ctx = SSLContext.getInstance("TLS");

    // Custom TrustManager that trusts all servers: every certificate chain is accepted.
    // NOTE(review): this disables server authentication.
    TrustManager trustAll = new X509TrustManager() {
      @Override
      public void checkClientTrusted(X509Certificate[] chain, String authType) {}

      @Override
      public void checkServerTrusted(X509Certificate[] chain, String authType) {}

      @Override
      public X509Certificate[] getAcceptedIssuers() {
        // The X509TrustManager contract requires a non-null (possibly empty) array.
        return new X509Certificate[0];
      }
    };

    // Key managers are built from a null key store, so no client certificate is
    // loaded; the password presumably has no effect here — TODO confirm.
    KeyManagerFactory keyManagerFactory =
        KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
    KeyStore keystore = null;
    String passwd = "mypassword";
    keyManagerFactory.init(keystore, passwd.toCharArray());

    ctx.init(keyManagerFactory.getKeyManagers(), new TrustManager[] { trustAll }, null);

    // Publish only a fully initialised context.
    sslContext = ctx;
  }

  /**
   * Sets the read timeout in milliseconds (0 = no timeout). It is applied to
   * the current socket, if any, and to sockets created later.
   */
  public void setReceiveTimeout(int msec) {
    receive_timeout = msec;
    try {
      if (sock != null) sock.setSoTimeout(receive_timeout);
    } catch (Exception ex) {
      error_msg = ex.toString();
    }
  }

  /** Sets the connect timeout in milliseconds used by later {@link #connect} calls. */
  public void setConnectTimeout(int msec) { connect_timeout = msec; }

  /** Returns true while a socket is connected and has not been closed. */
  public boolean connected() {
    return sock != null && sock.isConnected() && !sock.isClosed();
  }

  /**
   * Opens a TLS socket to {@code host:port}, closing any previous connection.
   * The TLS handshake itself is not started here; it happens on first I/O.
   *
   * @return true on success; on failure {@link #getError()} describes the cause
   */
  public boolean connect(String host, int port) {
    error_msg = null;
    if (factory == null) {
      error_msg = (init_error_msg != null) ? init_error_msg : "SSL context not initialized";
      return false;
    }

    // Release a previous connection instead of leaking its socket.
    closeQuietly();

    SSLSocket s = null;
    try {
      s = (SSLSocket) factory.createSocket();
      s.setSoTimeout(receive_timeout);
      SocketAddress addr = new InetSocketAddress(host, port);
      s.connect(addr, connect_timeout);
      in = s.getInputStream();
      out = s.getOutputStream();
      sock = s;
    } catch (Exception ex) {
      error_msg = ex.toString();
      // Do not keep (or leak) a half-opened socket.
      if (s != null) {
        try {
          s.close();
        } catch (IOException ignored) {
          // best effort: the connect error is the one worth reporting
        }
      }
      sock = null;
      in = null;
      out = null;
    }
    return error_msg == null;
  }

  /** Returns the description of the last failure, or null if the last call succeeded. */
  public String getError() { return error_msg; }

  /**
   * Closes the current connection. Calling it while already disconnected is a no-op.
   *
   * @return true unless closing the socket raised an error
   */
  public boolean disconnect() {
    error_msg = null;
    if (sock != null) {
      try {
        sock.close();
      } catch (IOException ex) {
        error_msg = ex.toString();
      }
    }
    sock = null;
    in = null;
    out = null;
    return error_msg == null;
  }

  /**
   * Writes all of {@code b} and flushes it.
   *
   * @return true on success; false if not connected or the write failed
   */
  public boolean sendBytes(byte b[]) {
    error_msg = null;
    if (!connected()) {
      error_msg = "not connected";
      return false;
    }
    try {
      out.write(b);
      out.flush();
    } catch (IOException ex) {
      error_msg = ex.toString();
    }
    return error_msg == null;
  }

  /** Sends {@code s} encoded as UTF-8; see {@link #sendBytes(byte[])}. */
  public boolean sendString(String s) {
    return sendBytes(s.getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Reads up to {@code buf.length} bytes into {@code buf}.
   *
   * @return the number of bytes read, or -1 at end of stream, when not
   *     connected, or on error (then {@link #getError()} is non-null)
   */
  public int receiveBytes(byte buf[]) {
    error_msg = null;
    if (!connected()) {
      error_msg = "not connected";
      return -1;
    }
    try {
      return in.read(buf);
    } catch (IOException ex) {
      error_msg = ex.toString();
      return -1;
    }
  }

  /**
   * Reads a single byte.
   *
   * @return the byte value (0..255), or -1 at end of stream, when not
   *     connected, or on error (then {@link #getError()} is non-null)
   */
  public int receiveByte() {
    error_msg = null;
    if (!connected()) {
      error_msg = "not connected";
      return -1;
    }
    try {
      return in.read();
    } catch (IOException ex) {
      error_msg = ex.toString();
      return -1;
    }
  }

  /**
   * Reads until end of stream, an error, or a receive timeout, and decodes
   * everything received as UTF-8.
   *
   * @return the received text (empty if nothing was read)
   */
  public String receiveString() {
    // Accumulate raw bytes and decode once, so multibyte characters that
    // straddle a read boundary are not corrupted.
    ByteArrayOutputStream received = new ByteArrayOutputStream();
    byte[] buf = new byte[1024];
    int len;
    while ((len = receiveBytes(buf)) > 0) {
      received.write(buf, 0, len);
    }
    return new String(received.toByteArray(), StandardCharsets.UTF_8);
  }

  /** Closes the current socket, if any, ignoring errors, and clears the streams. */
  private void closeQuietly() {
    if (sock != null) {
      try {
        sock.close();
      } catch (IOException ignored) {
        // best effort: the socket is being discarded anyway
      }
    }
    sock = null;
    in = null;
    out = null;
  }
}

