import socket
#import math
#from struct import unpack
#from array import array
import numpy as np
import time

class NovoptelUDP():
    """UDP client for register reads/writes and bulk SDRAM reads.

    Every request is six 16-bit words sent as the raw bytes of a uint16
    array: a command character ("R", "W" or "S"), four argument words and
    a checksum produced by :meth:`crc16`. Replies are parsed as uint16
    arrays in the same way.
    """

    # Class-level defaults; __init__ overrides them per instance.
    name = 'UDP'
    ip = '127.0.0.1'
    port = 5024
    sock = None
    debug = False

    def __init__(self, ip='127.0.0.1', port=5024, debug=False):
        """Store the peer address and immediately try to connect."""
        self.ip = ip
        self.port = port
        self.debug = debug
        self.connect()

    def connect(self):
        """Open a fresh UDP socket and probe the peer with a register read.

        The probe reads register 512+1. If it gets no valid answer, the
        socket is closed again and ``self.sock`` is left as ``None``.
        """
        self.close()  # never leak a previously opened socket
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)  # UDP
        self.sock.settimeout(0.3)
        dummy, ok = self.udp_read(512+1)
        if ok == 0:
            print('Connection timed out.')
            self.close()
        else:
            print('Connected')

    def close(self):
        """Close the socket (if one is open) and forget it. Safe to call twice."""
        if self.sock is not None:
            self.sock.close()
        self.sock = None

    def crc16(self, crc, data):
        """Fold 16-bit words into the running checksum *crc*.

        For every word, *crc* is rotated left by one bit within 16 bits and
        then XOR-ed with the word. The result stays within 16 bits as long as
        the words do.
        """
        for x in data:
            crc = ((crc << 1) | (crc >> 15)) & 0xFFFF
            crc ^= x
        return crc

    def udp_send(self, data: bytes):
        """Send one datagram to ``(self.ip, self.port)``.

        Reconnects first if no socket is open. Up to 10 attempts are made,
        sleeping 1 ms after each failed one. Send failures are not reported
        to the caller; the following receive will simply get no reply.
        """
        if self.sock is None:
            self.connect()
        if self.sock is None:
            # connect() failed: there is nothing to send on.
            return
        payload = bytes(data)
        for _ in range(10):
            try:
                self.sock.sendto(payload, (self.ip, self.port))
                return
            except OSError:
                time.sleep(0.001)  # sleep 1 ms before retrying

    def udp_receive(self, numbytes: int):
        """Collect datagrams until at least *numbytes* bytes have arrived.

        Returns ``(data, 1)`` with the concatenated payload as a bytearray.
        It may be longer than *numbytes* if the last datagram overshoots.
        Returns ``(0, 0)`` if no socket is open or a receive fails or times
        out; in that case any partially received data is discarded.
        """
        if self.sock is None:
            return 0, 0
        ans = bytearray()
        while len(ans) < numbytes:
            try:
                # One datagram per call, truncated to at most numbytes bytes.
                rx, _addr = self.sock.recvfrom(numbytes)
            except OSError:  # includes socket.timeout
                return 0, 0
            ans += rx
        return ans, 1

    def udp_read(self, addr: int):
        """Read one 16-bit register via an "R" request.

        The request is sent up to 20 times until a reply arrives. The reply
        is valid when word 0 is ``0xaaaa`` and word 1 equals the CRC of the
        returned value, chained from the request CRC.

        Returns ``(value, 1)`` on success, ``(value, 0)`` on a marker/CRC
        mismatch (an error is printed), or ``(0, 0)`` if no reply arrived.
        """
        req = np.array([0x52, addr, 0, 0, 0], dtype=np.uint16)  # "R"
        crc = self.crc16(0xFFFF, req)
        req = np.append(req, np.uint16(crc))
        packet = req.tobytes()

        ans, ok = 0, 0
        for _ in range(20):
            self.udp_send(packet)
            ans, ok = self.udp_receive(8)
            if ok:
                break
        if ok == 0:
            return 0, 0

        # Ignore a trailing odd byte instead of letting frombuffer raise.
        res = np.frombuffer(ans, np.uint16, count=len(ans) // 2)
        crcnew = self.crc16(crc, np.array([res[2]]))

        if res[0] == 0xaaaa and res[1] == crcnew:
            ok = 1
        else:
            ok = 0
            print('CRC ERROR during udp_read')
        return int(res[2]), ok

    # Disabled legacy code kept for reference only. It calls names that do
    # not exist in this class/module (self.udp_write, unpack).
    '''
    def udp_readbuffer(self, cmd: bytes, nblocks:int):
        
        rxok = False
        
        buffersize = 512 # addresses
                
        tries = 0
        while (rxok==False) and (tries<10):
            self.udp_write(cmd)
            try:
                
                if False:
                    ans, addr = self.sock.recvfrom(1024) # buffer size is 1024 bytes
                    rbytes = ans
                    ans, addr = self.sock.recvfrom(1024) # buffer size is 1024 bytes
                    rbytes += ans
                    ans, addr = self.sock.recvfrom(1024) # buffer size is 1024 bytes
                    rbytes += ans
                    ans, addr = self.sock.recvfrom(1024) # buffer size is 1024 bytes
                    rbytes += ans
                    data = unpack('<'+'H'*(len(rbytes)//2),rbytes)
                    arr = np.array(list(data))
                    rows = arr.shape[0] // 4
                    #if rows * 4 != arr.shape[0]:
                    #    rows -= 1
                    #    arr = arr[0:rows * 4].reshape(rows, 4)
                    arr = arr[0:rows * 4].reshape(rows, 4)
                
                if True: # this method seems a bit faster than the method above!
                    arr = np.empty([0, 4], dtype=np.uint16)
                    for ii in range(4*nblocks):
                        ans, addr = self.sock.recvfrom(1024) # buffer size is 1024 bytes
                        if (len(ans)==1024):
                            data = unpack('<'+'H'*(512),ans)
                            data_arr = np.array(list(data))
                            data_arr = data_arr[0:512].reshape(128, 4)
                            arr = np.concatenate([arr,data_arr])

                            
                if False: # this method takes double the time!
                    arr = np.empty([buffersize, 4], dtype=np.uint16)
                    ans, addr = self.sock.recvfrom(1024) # buffer size is 1024 bytes
                    for x in range(128):
                        for y in range(4):
                            ind = x*8 + y*2
                            arr[x,y] = (ans[ind+1]<<8)+ans[ind]
                    ans, addr = self.sock.recvfrom(1024) # buffer size is 1024 bytes
                    for x in range(128):
                        for y in range(4):
                            ind = x*8 + y*2
                            arr[128+x,y] = (ans[ind+1]<<8)+ans[ind]
                    ans, addr = self.sock.recvfrom(1024) # buffer size is 1024 bytes
                    for x in range(128):
                        for y in range(4):
                            ind = x*8 + y*2
                            arr[2*128+x,y] = (ans[ind+1]<<8)+ans[ind]
                    ans, addr = self.sock.recvfrom(1024) # buffer size is 1024 bytes
                    for x in range(128):
                        for y in range(4):
                            ind = x*8 + y*2
                            arr[3*128+x,y] = (ans[ind+1]<<8)+ans[ind]
                            
                rxok = True
            except:
                arr = []
                tries = tries + 1
                print('exception in sock.recvfrom.')

        #print('udp_readbuffer: %d tries.' % tries)
        return arr    
    '''

    def read(self, addr: int):
        """Return the value of register *addr*, or 0 if no reply arrived.

        The validity flag from :meth:`udp_read` is discarded, so a value
        that failed the CRC check is still returned.
        """
        res, ok = self.udp_read(addr)
        return res

    # Disabled legacy code kept for reference only. It depends on the
    # disabled udp_readbuffer above.
    '''
    def readbuffer(self, startaddr: int, numaddr:int):
        
        buffersize = 512 # addresses
        
        nblocks = 8
        
        res = np.empty([0, 4], dtype=np.uint16)
        
        while (numaddr>0):
            
            cmd=[0x53] # 'S'
            cmd.append(nblocks-1)
            cmd.append((startaddr>>24)&0xFF)
            cmd.append((startaddr>>16)&0xFF)
            cmd.append((startaddr>>8)&0xFF)
            cmd.append(startaddr&0xFF)
            
            #print(startaddr)
        
            
            block = self.udp_readbuffer(cmd,nblocks)
            
            
            addrinthisblock = min(buffersize*nblocks, numaddr)
            
            res = np.concatenate([res, block[:addrinthisblock,0:]])
            
            numaddr = numaddr - addrinthisblock
            
            startaddr = startaddr + addrinthisblock
        
        return res
    '''

    def write(self, addr: int, data: int):
        """Write *data* to register *addr* with a single "W" request.

        Returns 1 if the reply starts with the marker word ``0xaaaa``.
        Otherwise (no reply, or an unexpected reply) it prints an error and
        returns 0. The reply CRC is not checked and the request is not
        retried.
        """
        req = np.array([0x57, addr, data, 0, 0], dtype=np.uint16)  # "W"
        crc = self.crc16(0xFFFF, req)
        req = np.append(req, np.uint16(crc))
        self.udp_send(req.tobytes())

        ans, ok = self.udp_receive(8)
        # On timeout use a non-marker value so the error branch is taken.
        first_word = np.frombuffer(ans, np.uint16, count=1)[0] if ok else 0

        if first_word == 0xaaaa:
            return 1
        print('UNKNOWN ERROR during UDP write')
        return 0

    # Disabled legacy code (apparently from a stream-socket version) kept
    # for reference only. It calls names that do not exist here
    # (math, unpack, self.socket_write, self.s, self.reconnect).
    '''  
    
    def readsdram_sendrequest(self, startaddrseq: int, packetsinthissequence: int, cycles: int):
        
        # set transfer parameters in one command
        
        cmdlist = [ [512 + 105, startaddrseq & 0xFFFF],
                    [512 + 106, int((startaddrseq >> 16) + math.log2(cycles) * 2**12)],
                    [512 + 107, round(packetsinthissequence/cycles)],
                    [512 + 104, 39294]]
        #print(cmdlist)
        #print("cycles: ", cycles)
        
        cmd = []
        for i in range(len(cmdlist)):
            cmd.append(0x57) # 'W'
            for x in range(2):
                cmd.append((cmdlist[i][x]>>8)&0xFF)
                cmd.append(cmdlist[i][x]&0xFF)
                
         #print(cmd)
        self.socket_write(bytes(cmd))
        

      
    def readsdram_getpackets_raw(self, startaddrseq: int, packetsinthissequence: int, cycles: int):
        
        try:
            self.readsdram_sendrequest(startaddrseq, packetsinthissequence, cycles)
        except:
            return b''
        
        #print("packetsinthissequence %d" % (packetsinthissequence))
        debug = False
        
        rx_len = 0
        rxbytes = b''
        while (rx_len<packetsinthissequence):
            newbytes = b''
            try:
                newbytes = self.s.recv(2**14)
                rxbytes += newbytes
                if debug:
                    print("newbytes length: %d" % len(newbytes)) 
                if newbytes==0:
                    return b''
            except:
                #print("newbytes length: %d" % len(newbytes))  
                #print("rxbytes length: %d" % len(rxbytes)) 
                #input("Exception in self.s.recv!")
                ## test communication
                #print("ATE: %d" % self.read(512+1))
                #self.readsdram_sendrequest(startaddrseq, packetsinthissequence, cycles)
                #rxbytes = b''
                #debug = True
                return b''
                
            rx_len = len(rxbytes) / 8
            #print("rx_len %d" % (rx_len))
        
        #print("data length: %d" % len(rxbytes))
        
        #print(rxbytes)
        
        return rxbytes
        
    def readsdram_getpackets(self, startaddrseq: int, packetsinthissequence: int, cycles: int):
        
        packetsreceived = 0
        rxbytes = b''
        while packetsreceived<packetsinthissequence:
            try:
                rxbytes = self.readsdram_getpackets_raw(startaddrseq, packetsinthissequence, cycles)
                #print("readsdram_getpackets_raw length: %d" % len(rxbytes))  
                packetsreceived = len(rxbytes) / 8
                if packetsreceived==0:
                    #input("readsdram_getpackets_raw returned 0 bytes!")
                    self.reconnect()  
            except:
                print("Exception in readsdram_getpackets_raw!")
                self.reconnect()  
            
        
        try:
            data = unpack('>'+'H'*(len(rxbytes)//2),rxbytes)  # bytes to uint16
        except:
            print("data length: %d" % len(rxbytes))   
            print("Exception in readsdram_getpackets_raw!")
            
        arr = np.array(list(data))
        rows = arr.shape[0] // 4
        if rows * 4 != arr.shape[0]:
            rows -= 1
            arr = arr[0:rows * 4].reshape(rows, 4)
        arr = arr[0:rows * 4].reshape(rows, 4)
        
        return arr
    ''' 

    def readsdram_raw(self, startaddr: int, numaddr: int):
        """Fetch the raw bytes of up to 256*512 SDRAM addresses.

        Sends an "S" request carrying the start address as high/low 16-bit
        words and the number of 512-address blocks minus one. It then waits
        for ``8 * numaddr`` bytes (8 bytes per address). The request is
        retried up to 5 times on timeout.

        Returns ``(data, 1)`` on success or ``(0, 0)`` after 5 timeouts.
        Requesting zero addresses returns ``(bytearray(), 1)`` without
        sending anything.
        """
        numaddr = min(256*512, numaddr)
        if numaddr <= 0:
            # A block count of -1 cannot be encoded; nothing to fetch anyway.
            return bytearray(), 1
        numblocks = -(-numaddr // 512)  # ceil(numaddr / 512), as an int

        req = np.array([0x53, startaddr >> 16, startaddr & 0xFFFF, numblocks - 1, 0],
                       dtype=np.uint16)  # "S"
        crc = self.crc16(0xFFFF, req)
        req = np.append(req, np.uint16(crc))
        packet = req.tobytes()

        for _ in range(5):
            self.udp_send(packet)
            ans, ok = self.udp_receive(8*numaddr)
            if ok:
                return ans, 1
            print('Timeouot in readsdram_raw')
        return 0, 0

    def readsdram(self, startaddr: int, numaddr: int, normalization=-1):
        """Read *numaddr* SDRAM addresses and split them into four columns.

        Each address holds four uint16 words. They are fetched in chunks of
        2**13 addresses, with up to 5 attempts per chunk.

        normalization = -1: all four columns as uint16 (for adc data or test counter)
        normalization =  0: column 0 as uint16 (Stokes 0); columns 1-3 as int16
                            after subtracting 2**15 (Stokes 1 to 3)
        normalization =  1: column 0 as uint16; columns 1-3 as float32
                            after subtracting 2**15 and dividing by 2**15

        Returns ``(dout0, dout1, dout2, dout3)``. If a chunk cannot be read,
        it returns ``np.empty((0, 4))`` instead. That is not a 4-tuple; the
        form is kept for compatibility with existing callers.
        """
        bufsizeaddr = 2**13  # addresses requested per readsdram_raw call

        ansarr = bytearray()
        addrtransferred = 0
        while addrtransferred < numaddr:
            numaddrhere = min(bufsizeaddr, numaddr - addrtransferred)
            startaddrhere = startaddr + addrtransferred

            ok = 0
            for _ in range(5):  # bounded retries (previously could spin forever)
                ans, ok = self.readsdram_raw(startaddrhere, numaddrhere)
                if ok:
                    break
            if not ok:
                return np.empty((0, 4))

            ansarr.extend(ans)
            addrtransferred += numaddrhere

        # Keep only complete 4-word (8-byte) rows; a trailing partial row is dropped.
        nwords = (len(ansarr) // 8) * 4
        data = np.frombuffer(ansarr, np.uint16, count=nwords).reshape(-1, 4)

        dout0 = data[:, 0].astype(np.uint16)
        cols = [data[:, k] for k in (1, 2, 3)]
        if normalization == 0:
            dout1, dout2, dout3 = (
                (c.astype(np.int32) - 2**15).astype(np.int16) for c in cols)
        elif normalization == 1:
            dout1, dout2, dout3 = (
                (c.astype(np.int32) - 2**15).astype(np.float32) / 2**15 for c in cols)
        else:
            dout1, dout2, dout3 = (c.astype(np.uint16) for c in cols)

        return dout0, dout1, dout2, dout3

    def gethistogram(self, histtype=0):
        """Read a 1024-bin histogram from the peer.

        histtype 0: sop speed histogram (data register 512+231)
        otherwise:  power histogram (data register 512+229)

        Each bin is assembled from three 16-bit register reads. Before each
        read, register 512+230 is written with the bin index plus a
        word-select bit: 2**15 selects the high word, 2**14 the middle word,
        and no bit the low word. Returns a uint64 array of length 1024.
        """
        datareg = 512 + 231 if histtype == 0 else 512 + 229
        selectreg = 512 + 230

        data = np.empty(1024, dtype=np.uint64)
        for ii in range(1024):
            self.write(selectreg, 2**15 + ii)
            word1 = int(self.read(datareg))
            self.write(selectreg, 2**14 + ii)
            word2 = int(self.read(datareg))
            self.write(selectreg, ii)
            word3 = int(self.read(datareg))

            # Combine as Python ints: NumPy 1.x promotes uint64 * int to float64.
            data[ii] = (word1 << 32) + (word2 << 16) + word3

        return data
            
