package com.algobase.share.network;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.IOException;

import java.net.SocketAddress;
import java.net.InetSocketAddress;
import java.net.InetAddress;
import java.net.Inet4Address;
import java.net.Inet6Address;

import java.net.Socket;

import java.nio.charset.StandardCharsets;

/**
 * A {@link Socket} that exchanges length-prefixed messages.
 *
 * <p>Every payload (byte array, string, int, or file contents) is preceded by a
 * size header of {@code numSizeBytes} bytes (default 4), least-significant byte
 * first. Strings are encoded as UTF-8; ints are sent as their decimal string.
 *
 * <p>Public methods do not throw on I/O failure. Instead they return
 * {@code false}, an empty string, or 0, and record a description that can be
 * read with {@link #getError()}.
 *
 * <p>NOTE(review): no synchronization is used; presumably each instance is
 * driven by a single thread — confirm with callers.
 */
public class LedaSocket extends Socket {

    /** Chunk size used when streaming file contents. */
    private final int buf_sz = 1 << 12;

    /** Largest string/int message accepted by {@link #receiveString()} (1 MiB). */
    private static final int MAX_MESSAGE_SIZE = 1 << 20;

    /** Width in bytes of each little-endian size header. */
    private int numSizeBytes = 4;

    /** Connect timeout in milliseconds, passed to {@code Socket.connect}. */
    private int timeout;
    private boolean is_connected;

    /** Description of the last failure, or {@code null} if the last operation succeeded. */
    protected String error_msg;

    private OutputStream out;
    private InputStream in;

    /** Hook called before a file transfer with the total byte count; no-op by default. */
    public void progress_init(int sz) {}

    /** Hook called after each file chunk with bytes transferred so far and the total; no-op by default. */
    public void progress(int n, int sz) {}

    /**
     * Reads a little-endian size header of {@code numSizeBytes} bytes.
     *
     * @return the decoded size, or -1 if the stream ended or an I/O error occurred
     */
    private long readSize() {
        long sz = 0;
        long factor = 1;
        try {
            for (int i = 0; i < numSizeBytes; i++) {
                int b = in.read();
                if (b < 0) {
                    return -1; // EOF inside the header: the peer closed the connection
                }
                sz += b * factor;
                factor *= 256;
            }
        } catch (IOException ex) {
            return -1;
        }
        return sz;
    }

    /** Writes {@code sz} as a little-endian header of {@code numSizeBytes} bytes and flushes. */
    private void writeSize(long sz) throws IOException {
        byte[] header = new byte[numSizeBytes];
        for (int i = 0; i < numSizeBytes; i++) {
            header[i] = (byte) (sz % 256);
            sz /= 256;
        }
        out.write(header); // one write instead of numSizeBytes single-byte writes
        out.flush();
    }

    /**
     * Checks that {@code sz} fits both in an {@code int} (receivers decode sizes
     * into an int) and in a {@code numSizeBytes}-byte header.
     *
     * @return {@code sz} as an int
     * @throws IOException if the size cannot be represented
     */
    private int checkedSize(long sz) throws IOException {
        long limit = Integer.MAX_VALUE;
        if (numSizeBytes < 4) {
            limit = (1L << (8 * numSizeBytes)) - 1;
        }
        if (sz > limit) {
            throw new IOException("size " + sz + " exceeds limit " + limit);
        }
        return (int) sz;
    }

    /**
     * Fills {@code b} completely from the socket; a single read may return
     * fewer bytes than requested.
     *
     * @return {@code false} if the stream ended before {@code b} was filled
     */
    private boolean readFully(byte[] b) throws IOException {
        int off = 0;
        while (off < b.length) {
            int n = in.read(b, off, b.length - off);
            if (n < 0) {
                return false;
            }
            off += n;
        }
        return true;
    }

    /**
     * Copies bytes from {@code src} to the socket until the running count
     * reaches {@code limit} or {@code src} is exhausted. Reports progress
     * against {@code total}.
     *
     * @param count bytes already sent before this call
     * @return the updated running count
     */
    private int sendStream(InputStream src, byte[] buf, int count, int limit, int total)
            throws IOException {
        while (count < limit) {
            int n = src.read(buf, 0, Math.min(buf.length, limit - count));
            if (n <= 0) {
                break;
            }
            out.write(buf, 0, n);
            out.flush();
            count += n;
            progress(count, total);
        }
        return count;
    }

    /** Returns true if connected; otherwise records "{op}: not connected" and returns false. */
    private boolean requireConnected(String op) {
        if (connected()) {
            return true;
        }
        error_msg = op + ": not connected";
        return false;
    }

    /**
     * Receives one framed payload into {@code file}, overwriting it.
     *
     * @return {@code true} iff exactly the announced number of bytes was written
     */
    private boolean readFile(File file) throws IOException {
        // The file is opened before the header is read, so a failed open does not consume the payload.
        try (FileOutputStream fileOut = new FileOutputStream(file)) {
            long total = readSize();
            if (total < 0 || total > Integer.MAX_VALUE) {
                return false;
            }
            int sz = (int) total;
            progress_init(sz);

            byte[] buffer = new byte[buf_sz];
            int count = 0;
            while (count < sz) {
                // Never read past this payload: following bytes belong to the next message.
                int n = in.read(buffer, 0, Math.min(buffer.length, sz - count));
                if (n < 0) {
                    break;
                }
                fileOut.write(buffer, 0, n);
                count += n;
                progress(count, sz);
            }
            return count == sz;
        }
    }

    /** Sends {@code file} as a single framed payload. */
    private void writeFile(File file) throws IOException {
        int len = checkedSize(file.length());
        progress_init(len);
        byte[] bytes = new byte[buf_sz];
        try (FileInputStream fileIn = new FileInputStream(file)) {
            writeSize(len);
            // Never send more than announced, even if the file grew meanwhile.
            int count = sendStream(fileIn, bytes, 0, len, len);
            if (count < len) {
                throw new IOException("file shrank while sending: " + file);
            }
        }
    }

    /** Sends the first {@code num_files} files concatenated as one framed payload. */
    private void writeFiles(File[] files, int num_files) throws IOException {
        byte[] bytes = new byte[buf_sz];

        long[] sizes = new long[num_files];
        long total = 0; // long: the sum of several files can exceed Integer.MAX_VALUE
        for (int i = 0; i < num_files; i++) {
            sizes[i] = files[i].length();
            total += sizes[i];
        }
        int len = checkedSize(total);
        writeSize(len);
        progress_init(len);

        int count = 0;
        for (int i = 0; i < num_files; i++) {
            try (FileInputStream fileIn = new FileInputStream(files[i])) {
                // Each file contributes exactly the length measured above.
                int end = count + (int) sizes[i];
                count = sendStream(fileIn, bytes, count, end, len);
                if (count < end) {
                    throw new IOException("file shrank while sending: " + files[i]);
                }
            }
        }
    }

    // public

    /** Creates an unconnected socket with a 5000 ms connect timeout. */
    public LedaSocket() {
        super();
        System.setProperty("java.net.preferIPv4Stack", "true");
        timeout = 5000;
        is_connected = false;
    }

    /**
     * Sets the width of the size header exchanged with the peer.
     *
     * @throws IllegalArgumentException if {@code num} is not in 1..8
     */
    public void setNumSizeBytes(int num) {
        if (num < 1 || num > 8) {
            throw new IllegalArgumentException("numSizeBytes must be in 1..8: " + num);
        }
        numSizeBytes = num;
    }

    /** Sets the connect timeout in milliseconds used by {@link #connect(String, int)}. */
    public void setTimeout(int msec) { timeout = msec; }

    /**
     * Returns whether the socket has been connected and not closed locally.
     * This does not detect that the peer has closed the connection.
     */
    public boolean connected() { return !isClosed() && isConnected(); }

    /**
     * Connects to {@code address:port} using the configured timeout.
     *
     * @return {@code true} on success or if already connected; otherwise
     *         {@code false} with {@link #getError()} set
     */
    public boolean connect(String address, int port) {
        error_msg = null;

        if (is_connected) {
            return true;
        }

        // NOTE: the JVM typically reads this property only once, when networking is first initialized.
        System.setProperty("java.net.preferIPv4Stack", "true");

        try {
            SocketAddress sock_addr = new InetSocketAddress(address, port);
            super.connect(sock_addr, timeout);
            in = getInputStream();
            out = getOutputStream();
        } catch (Exception ex) { // broad on purpose: also covers IllegalArgumentException for a bad port
            error_msg = "connect: " + ex.toString();
        }

        is_connected = (error_msg == null);
        return is_connected;
    }

    /** Returns a description of the last failure, or {@code null} if it succeeded. */
    public String getError() { return error_msg; }

    /** Closes the socket; a failure is recorded in {@link #getError()}. */
    public void disconnect() {
        error_msg = null;
        is_connected = false;
        try {
            close();
        } catch (Exception ex) {
            error_msg = "disconnect: " + ex.toString();
        }
    }

    /** Writes raw bytes with no size header. */
    public boolean writeBytes(byte[] b) {
        if (!requireConnected("writeBytes")) {
            return false;
        }
        error_msg = null;
        try {
            out.write(b);
            out.flush();
        } catch (IOException ex) {
            error_msg = "writeBytes: " + ex.toString();
        }
        return error_msg == null;
    }

    /** Sends {@code b} preceded by its size header. */
    public boolean sendBytes(byte[] b) {
        if (!requireConnected("sendBytes")) {
            return false;
        }
        error_msg = null;
        try {
            writeSize(checkedSize(b.length));
            out.write(b);
            out.flush();
        } catch (IOException ex) {
            error_msg = "sendBytes: " + ex.toString();
        }
        return error_msg == null;
    }

    /** Sends {@code s} as a framed UTF-8 message. */
    public boolean sendString(String s) {
        return sendBytes(s.getBytes(StandardCharsets.UTF_8));
    }

    /** Sends {@code x} as a framed decimal string. */
    public boolean sendInt(int x) {
        // Integer.toString is locale-independent; String.format("%d") may localize digits.
        return sendString(Integer.toString(x));
    }

    /** Sends the contents of {@code file} as one framed payload. */
    public boolean sendFile(File file) {
        if (!requireConnected("sendFile")) {
            return false;
        }
        error_msg = null;
        try {
            writeFile(file);
        } catch (IOException ex) {
            error_msg = "sendFile: " + ex.toString();
        }
        return error_msg == null;
    }

    /** Sends all {@code files} concatenated as one framed payload. */
    public boolean sendFiles(File[] files) {
        return sendFiles(files, files.length);
    }

    /** Sends the first {@code n} of {@code files} concatenated as one framed payload. */
    public boolean sendFiles(File[] files, int n) {
        if (!requireConnected("sendFiles")) {
            return false;
        }
        error_msg = null;
        try {
            writeFiles(files, n);
        } catch (IOException ex) {
            error_msg = "sendFiles: " + ex.toString();
        }
        return error_msg == null;
    }

    /**
     * Receives one framed message of at most {@link #MAX_MESSAGE_SIZE} bytes.
     *
     * @return the payload, or {@code null} on failure with {@link #getError()} set
     */
    private byte[] receiveBytes() {
        if (!requireConnected("receiveBytes")) {
            return null;
        }
        error_msg = null; // a stale error from an earlier call must not discard this message

        long sz = readSize();
        if (sz < 0) {
            error_msg = "receiveBytes: connection closed or read error";
            return null;
        }
        if (sz > MAX_MESSAGE_SIZE) {
            // NOTE(review): the oversized payload is left unread, so the stream is no longer in sync.
            error_msg = "socket: string too long";
            return null;
        }

        byte[] b = new byte[(int) sz];
        try {
            if (!readFully(b)) {
                error_msg = "receiveBytes: connection closed mid-message";
            }
        } catch (IOException ex) {
            error_msg = "receiveBytes: " + ex.toString();
        }
        return error_msg == null ? b : null;
    }

    /** Receives a framed UTF-8 string; returns "" on failure (see {@link #getError()}). */
    public String receiveString() {
        byte[] b = receiveBytes();
        return b == null ? "" : new String(b, StandardCharsets.UTF_8);
    }

    /** Receives a framed decimal integer; returns 0 on failure or if it does not parse. */
    public int receiveInt() {
        String s = receiveString();
        try {
            return Integer.parseInt(s);
        } catch (NumberFormatException ignored) {
            return 0; // deliberate best-effort default
        }
    }

    /**
     * Receives the next string and reports whether it equals {@code s}.
     * Unrelated to {@link Object#wait()}.
     */
    public boolean wait(String s) {
        String x = receiveString();
        return x.equals(s);
    }

    /**
     * Receives one framed payload into {@code file}.
     *
     * @return {@code true} only if the complete payload was written
     */
    public boolean receiveFile(File file) {
        if (!requireConnected("receiveFile")) {
            return false;
        }
        error_msg = null;
        try {
            if (!readFile(file)) {
                error_msg = "receiveFile: incomplete transfer";
            }
        } catch (IOException ex) {
            error_msg = "receiveFile: " + ex.toString();
        }
        return error_msg == null;
    }

}

