#include <winsock2.h>
#include <Ipexport.h>
#include <process.h>
#include <stdio.h>
#include <stdint.h>
#include "pingscan.h"
#include "raw_ping.h"
#include "func.h"

/* Forward declarations of helpers used by the functions below. */

/* Writes the ICMP echo-request header fields and the 'Z' payload filler. */
void fill_icmp_data(char * icmp_data, uint32_t data_size);
/* Computes the ICMP checksum over 'size' bytes starting at 'buffer'. */
uint16_t checksum(uint16_t *buffer, uint32_t size);
/* NOTE(review): no definition of decode_resp() is visible in this part of the
 * file - confirm it is still needed. */
void decode_resp(char *,uint32_t ,struct sockaddr_in *);
/* Validates a received packet; returns 0 when it is the expected echo reply. */
uint32_t check_resp(char *buf, uint32_t bytes, uint32_t dst, uint16_t icmp_id);

/* send_echoes()
 * Send ICMP echo requests to host 'ip_addr' through a raw socket.
 *
 * 'ip_addr'           - destination IPv4 address, stored unchanged into
 *                       sin_addr.s_addr.
 * 'data_size'         - ICMP payload size in bytes. Raised to DEF_PACKET_SIZE
 *                       when smaller and capped so that header + payload fit
 *                       into the MAX_PACKET-byte send buffer.
 * 'packets_count'     - number of echoes to send. A value below 1 selects the
 *                       endless mode: the loop counter is never advanced, so
 *                       sending stops only when sendto() fails.
 * 'critical_section'  - must not be NULL. Handed to safe_inc_dec() together
 *                       with 'unfinished_echoes' after every sent packet.
 * 'unfinished_echoes' - pointer to the counter of echoes that have no answer
 *                       yet.
 *
 * Returns the value of the sent-packets counter (it stays 0 in the endless
 * mode), or 0 when the socket or the buffer could not be set up.
 */
uint32_t send_echoes(uint32_t ip_addr, int32_t data_size, int32_t packets_count, 
                        CRITICAL_SECTION * critical_section, uint32_t * unfinished_echoes)
{
    SOCKET raw_sock;
    struct sockaddr_in dest;
    uint32_t timeout = 1000;    /* send timeout, milliseconds */
    char *icmp_data;
    icmp_header *icmp_hdr;
    uint32_t seq_num = 0;
    uint32_t i;
    uint8_t incrementor = 1;
    uint32_t curr_unfinished_count = 0;
    uint32_t last_unfinished_count = 0;
    int32_t diff = 0;
    int32_t packets_sended = 0;
    /* The ICMP identifier field is 16 bits wide, so the PID is truncated on purpose. */
    uint16_t icmp_id = (uint16_t)GetCurrentProcessId();
    int32_t bwrote;
    int last_error;
    
    if(packets_count < 1) 
    {
        /* Endless mode: with a zero step the loop condition never becomes false. */
        packets_count = 1;
        incrementor = 0;
    }
    
    raw_sock = WSASocket (AF_INET, SOCK_RAW, IPPROTO_ICMP, NULL, 0,0);
    if (raw_sock == INVALID_SOCKET) 
    {
        fprintf(stderr,"WSASocket() failed: %d\n", WSAGetLastError());
        return 0;
    }
    /* Set SND timeout for socket */
    if(setsockopt(raw_sock, SOL_SOCKET, SO_SNDTIMEO,(char*)&timeout, sizeof(timeout)) == SOCKET_ERROR) 
    {
        fprintf(stderr,"Failed to set send timeout: %d\n", WSAGetLastError());
        closesocket(raw_sock);
        return 0;
    }
    memset(&dest, 0, sizeof(dest));
    dest.sin_addr.s_addr = ip_addr;
    dest.sin_family = AF_INET;

    if(data_size < DEF_PACKET_SIZE)
        data_size = DEF_PACKET_SIZE;    /* ICMP data size */
    /* Never build a packet larger than the buffer allocated below. */
    if(data_size > (int32_t)(MAX_PACKET - sizeof(icmp_header)))
        data_size = (int32_t)(MAX_PACKET - sizeof(icmp_header));
    data_size += (int32_t)sizeof(icmp_header);    /* + size of ICMP header */

    /* NOTE(review): the buffer is not released here because the routine that
     * pairs with xmalloc() is not visible from this file - confirm and free it. */
    icmp_data = (char*)xmalloc(MAX_PACKET);
    if (!icmp_data)
    {
        fprintf(stderr,"HeapAlloc failed %ld\n", (long)GetLastError());
        closesocket(raw_sock);
        return 0;
    }
    memset(icmp_data, 0, MAX_PACKET);
    fill_icmp_data(icmp_data, (uint32_t)data_size);
    icmp_hdr = (icmp_header*)icmp_data;
    icmp_hdr->i_id = icmp_id;

    /* Sending ICMP packets */
    for(;packets_sended < packets_count; packets_sended += incrementor)
    {
        /* The checksum field must be zero while the checksum is computed. */
        icmp_hdr->i_cksum = 0;
        icmp_hdr->timestamp = GetTickCount();
        icmp_hdr->i_seq = htons((uint16_t)seq_num);
        icmp_hdr->i_cksum = checksum((uint16_t*)icmp_data, (uint32_t)data_size);
        bwrote = sendto(raw_sock, icmp_data, data_size, 0, (struct sockaddr*)&dest, sizeof(dest));
        if (bwrote == SOCKET_ERROR)
        {
            /* Capture the code first: the printf below may overwrite it. */
            last_error = WSAGetLastError();
            if (last_error == WSAETIMEDOUT) 
                printf("Timed out\n");
            fprintf(stderr,"sendto failed: %d\n", last_error);
            break;
        }
        if (bwrote < data_size )
            fprintf(stdout,"Wrote %d bytes\n",bwrote);
        ++seq_num;
        ++total_sended;    /* Increment global variable. Needed for print statistic on interrupt */
        safe_inc_dec(critical_section, unfinished_echoes, 1);
        /* Calculate packets loss */
        curr_unfinished_count = *unfinished_echoes;
        diff = (int32_t)(curr_unfinished_count - last_unfinished_count);
        /* Print loss packets: one dot per growth of more than 10 outstanding echoes.
         * NOTE(review): the baseline is refreshed only when a dot is printed, not
         * when one is erased - confirm this is intended. */
        if( diff > 10)
        {
            printf(".");
            last_unfinished_count = curr_unfinished_count;
        }
        else if(diff < -10)
            printf("\b \b");

        /* Give up the time slice a number of times between packets,
         * presumably to throttle the send rate. */
        for(i = 0; i < 0x130; ++i)
            Sleep(0);
    }
    closesocket(raw_sock);
    return packets_sended;
}

/* recv_echoes()
 * Receive ICMP echo replies coming from host 'ip_addr' on a raw socket.
 *
 * 'ip_addr'           - address the replies must originate from; packets that
 *                       check_resp() rejects are ignored and not counted.
 * 'data_size'         - unused: every datagram is read into a MAX_PACKET-byte
 *                       buffer. Kept for interface symmetry with send_echoes().
 * 'packets_count'     - number of valid replies to wait for. A value below 1
 *                       selects the endless mode (the counter never advances).
 * 'critical_section'  - must not be NULL. Handed to safe_inc_dec() together
 *                       with 'unfinished_echoes' for every accepted reply.
 * 'unfinished_echoes' - pointer to the counter of echoes that have no answer
 *                       yet.
 *
 * Returns the number of accepted replies, or 0 when the socket or the buffer
 * could not be set up. A bind() or recvfrom() failure terminates the process
 * with STATUS_FAILED.
 */
uint32_t recv_echoes(uint32_t ip_addr, int32_t data_size, int32_t packets_count, 
                        CRITICAL_SECTION * critical_section, uint32_t * unfinished_echoes)
{
    SOCKET raw_sock;
    struct sockaddr_in recv_from;
    struct sockaddr_in socket_listen_address;
    int32_t bytes_read;
    int from_len;
    char *recvbuf;
    uint8_t incrementor = 1;
    int32_t recv_packets_count = 0;
    /* The ICMP identifier field is 16 bits wide, so the PID is truncated on purpose. */
    uint16_t icmp_id = (uint16_t)GetCurrentProcessId();
    int last_error;
    
    (void)data_size;

    if(packets_count < 1)
    {
        /* Endless mode: with a zero step the loop condition never becomes false. */
        packets_count = 1;
        incrementor = 0;
    }
    
    raw_sock = WSASocket (AF_INET, SOCK_RAW, IPPROTO_ICMP, NULL, 0, 0);
    if (raw_sock == INVALID_SOCKET)
    {
        fprintf(stderr,"WSASocket() failed: %d\n", WSAGetLastError());
        return 0;
    }
    
    /* NOTE(review): the buffer is not released here because the routine that
     * pairs with xmalloc() is not visible from this file - confirm and free it. */
    recvbuf = (char*)xmalloc(MAX_PACKET);
    if (!recvbuf)
    {
        fprintf(stderr,"HeapAlloc failed %ld\n", (long)GetLastError());
        closesocket(raw_sock);
        return 0;
    }

    /* Zero the whole structure so the padding bytes are defined as well. */
    memset(&socket_listen_address, 0, sizeof(socket_listen_address));
    socket_listen_address.sin_family = AF_INET;
    socket_listen_address.sin_addr.s_addr = htonl(INADDR_ANY);
    socket_listen_address.sin_port = htons(0);
    if(bind(raw_sock, (SOCKADDR*) &socket_listen_address, sizeof(socket_listen_address)) == SOCKET_ERROR)
    {
        fprintf(stderr,"bind failed: %d\n", WSAGetLastError());
        ExitProcess(STATUS_FAILED);        
    }

    while(recv_packets_count < packets_count)
    {
        /* recvfrom() treats the length as in/out, so restore it before every call. */
        from_len = sizeof(recv_from);
        bytes_read = recvfrom(raw_sock, recvbuf, MAX_PACKET, 0, (struct sockaddr*)&recv_from, &from_len);
        if (bytes_read == SOCKET_ERROR)
        {
            /* Capture the code first: the printf below may overwrite it. */
            last_error = WSAGetLastError();
            if (last_error == WSAETIMEDOUT) 
            {
                  printf("timed out\n");
            }
            fprintf(stderr,"recvfrom failed: %d\n", last_error);
            fprintf(stderr, "Aborting programm.\n");
            ExitProcess(STATUS_FAILED);        
        }
        /* Anything that is not our echo reply is skipped without being counted. */
        if(check_resp(recvbuf, (uint32_t)bytes_read, ip_addr, icmp_id))
            continue;
        safe_inc_dec(critical_section, unfinished_echoes, 0);
        recv_packets_count += incrementor;
    }
    closesocket(raw_sock);
    return recv_packets_count;
}

/* recving_echoes_threads()
 * Thread entry point that unpacks a 'struct recving_echoes_params' passed
 * through 'arg' and forwards its fields to recv_echoes().
 * The result of recv_echoes() is discarded; the thread ends with code 0.
 */
unsigned __stdcall recving_echoes_threads( void* arg )
{
    struct recving_echoes_params * params = (struct recving_echoes_params*) arg;
    uint32_t target_ip = params->dst_ip;
    int32_t payload_size = params->common_params->packet_size;
    int32_t wanted_replies = params->common_params->packets_count;

    recv_echoes(target_ip, payload_size, wanted_replies,
                params->critical_section, params->unfinished_echoes);

    _endthreadex(0);
    return 0;
}

/* check_resp()
 * Check a received packet (IP header followed by an ICMP message).
 *
 * 'buf'     - start of the received datagram, beginning with the IP header.
 * 'bytes'   - number of valid bytes in 'buf'.
 * 'dst'     - address the reply must come from (compared with source_ip).
 * 'icmp_id' - ICMP identifier the reply must carry.
 *
 * Returns 0 on success, otherwise a non-zero reason code:
 *   1 - packet too short or IP header length field invalid
 *   2 - not an ICMP echo reply
 *   3 - ICMP identifier does not match 'icmp_id'
 *   4 - source address does not match 'dst'
 */
uint32_t check_resp(char *buf, uint32_t bytes, uint32_t dst, uint16_t icmp_id) 
{
    /* Smallest legal IPv4 header: 5 words of 32 bits. */
    enum { IPV4_MIN_HEADER_LEN = 20 };
    ip_header *iphdr;
    icmp_header *icmphdr;
    uint16_t iphdrlen;

    /* Do not look at header fields that were not actually received. */
    if (bytes < IPV4_MIN_HEADER_LEN)
        return 1;
    iphdr = (ip_header *)buf;
    iphdrlen = (uint16_t)(iphdr->h_len * 4); /* number of 32-bit words *4 = bytes */
    /* A shorter length would place the ICMP header inside the IP header. */
    if (iphdrlen < IPV4_MIN_HEADER_LEN)
        return 1;
    if (bytes < (uint32_t)iphdrlen + ICMP_MIN_SIZE) 
        return 1;
    icmphdr = (icmp_header*)(buf + iphdrlen);
    if (icmphdr->i_type != ICMP_ECHOREPLY) 
        return 2;
    if (icmphdr->i_id != icmp_id) 
        return 3;
    if (iphdr->source_ip != dst)
        return 4;
    return 0;
}

/* checksum()
 * Calculating ICMP checksum: the 16-bit one's complement of the one's
 * complement sum of the data, taken as native 16-bit words.
 *
 * 'buffer' - start of the data to sum.
 * 'size'   - length of the data in bytes; a trailing odd byte is added as is.
 *
 * Returns the checksum. Storing it into a zeroed checksum field of an
 * even-sized buffer makes a second call over that buffer return 0.
 */
uint16_t checksum(uint16_t *buffer, uint32_t size) 
{
    /* 64 bits cannot overflow for any 32-bit byte count. */
    uint64_t sum = 0;

    while(size > 1) 
    {
        sum += *buffer++;
        size -= (uint32_t)sizeof(uint16_t);
    }
    if(size) 
    {
        sum += *(const uint8_t*)buffer;
    }
    /* Fold the carries back in until the sum fits into 16 bits. */
    while(sum >> 16)
    {
        sum = (sum >> 16) + (sum & 0xffff);
    }
    return (uint16_t)(~sum);
}

/* fill_icmp_data()
 * Filling ICMP data in packet: sets the type to ICMP_ECHO, zeroes the code,
 * checksum and sequence fields and fills the payload with 'Z' bytes.
 * The identifier and timestamp fields are not touched.
 *
 * 'icmp_data' - buffer that receives the packet.
 * 'data_size' - total packet size in bytes, ICMP header included. When it is
 *               smaller than the header nothing is written, because the
 *               buffer cannot be assumed to hold even the header.
 */
void fill_icmp_data(char * icmp_data, uint32_t data_size)
{
    icmp_header *icmp_hdr;
    char *datapart;

    /* Without this guard the payload length below would wrap around. */
    if (data_size < sizeof(icmp_header))
        return;

    icmp_hdr = (icmp_header*)icmp_data;
    icmp_hdr->i_type = ICMP_ECHO;
    icmp_hdr->i_code = 0;
    icmp_hdr->i_cksum = 0;
    icmp_hdr->i_seq = 0;
    datapart = icmp_data + sizeof(icmp_header);
    memset(datapart,'Z', data_size - sizeof(icmp_header));
    return;
}