From 869288b96eedc3a0df22bdecaa842e6901044324 Mon Sep 17 00:00:00 2001
From: Peter Svihra <peter.svihra@cern.ch>
Date: Wed, 3 Jun 2020 11:13:31 +0200
Subject: [PATCH] protocol refactor with proper sequence send/receive counting

- safety checks on consecutive packet sequence mismatches
- lowered all sending timeouts to 5 ms
- better ACK/NACK propagation
---
 arduino/common/lib/CommsControl/CommsCommon.h |  10 +-
 .../common/lib/CommsControl/CommsControl.cpp  | 127 ++++++++++--------
 .../common/lib/CommsControl/CommsControl.h    |  13 +-
 raspberry-dataserver/CommsFormat.py           |   9 +-
 raspberry-dataserver/CommsLLI.py              |  63 +++++++--
 5 files changed, 141 insertions(+), 81 deletions(-)

diff --git a/arduino/common/lib/CommsControl/CommsCommon.h b/arduino/common/lib/CommsControl/CommsCommon.h
index 6ecdb880..abb23acf 100644
--- a/arduino/common/lib/CommsControl/CommsCommon.h
+++ b/arduino/common/lib/CommsControl/CommsCommon.h
@@ -3,9 +3,9 @@
 
 #include <Arduino.h>
 
-#define CONST_TIMEOUT_ALARM 5
-#define CONST_TIMEOUT_DATA  10 
-#define CONST_TIMEOUT_CMD   50
+#define COSNT_TIMEOUT_TRANSFER 5    // gap in ms (compared against millis()) that must elapse before sender() transmits again; NOTE(review): "COSNT" is a typo of "CONST" - rename together with its use in CommsControl::sender()
+#define CONST_PACKET_RETRIES 10     // resendPacket() gives up on the in-flight packet once its retry counter reaches this value
+#define CONST_MISMATCH_COUNTER 20   // trackMismatch() resynchronises the receive sequence once its counter exceeds this value
 
 #define PAYLOAD_MAX_SIZE_BUFFER 128
 
@@ -30,8 +30,8 @@
 #define COMMS_CONTROL_SUPERVISORY 0x01
 
 #define COMMS_CONTROL_TYPES 0x0F
-#define COMMS_CONTROL_ACK   0x00 | COMMS_CONTROL_SUPERVISORY
-#define COMMS_CONTROL_NACK  0x04 | COMMS_CONTROL_SUPERVISORY
+#define COMMS_CONTROL_ACK   (0x00 | COMMS_CONTROL_SUPERVISORY)
+#define COMMS_CONTROL_NACK  (0x04 | COMMS_CONTROL_SUPERVISORY)
 
 #define PACKET_TYPE  0xC0
 #define PACKET_ALARM 0xC0
diff --git a/arduino/common/lib/CommsControl/CommsControl.cpp b/arduino/common/lib/CommsControl/CommsControl.cpp
index e8f601da..e0ad8674 100644
--- a/arduino/common/lib/CommsControl/CommsControl.cpp
+++ b/arduino/common/lib/CommsControl/CommsControl.cpp
@@ -15,14 +15,6 @@ CommsControl::CommsControl(uint32_t baudrate) {
     memset(_comms_received, 0, sizeof(_comms_received));
     memset(_comms_send    , 0, sizeof(_comms_send    ));
 
-//    _ring_buff_alarm = RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING>();
-//    _ring_buff_data  = RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING>();
-//    _ring_buff_cmd   = RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING>();
-
-//    _ring_buff_received = RingBuf<Payload, COMMS_MAX_SIZE_RB_RECEIVING>();
-
-//    _comms_tmp   = CommsFormat(COMMS_MAX_SIZE_PACKET - COMMS_MIN_SIZE_PACKET );
-
     CommsFormat::generateACK(_comms_ack);
     CommsFormat::generateNACK(_comms_nck);
 
@@ -41,16 +33,14 @@ void CommsControl::beginSerial() {
 // main function to always call and try and send data
 // _last_trans_time is changed when transmission occurs in sendQueue
 void CommsControl::sender() {
-    if (static_cast<uint32_t>(millis()) - _last_trans_time > CONST_TIMEOUT_ALARM) {
-        sendQueue(&_ring_buff_alarm);
-    }
-
-    if (static_cast<uint32_t>(millis()) - _last_trans_time > CONST_TIMEOUT_CMD) {
-        sendQueue(&_ring_buff_cmd);
-    }
-
-    if (static_cast<uint32_t>(millis()) - _last_trans_time > CONST_TIMEOUT_DATA) {
-        sendQueue(&_ring_buff_data);
+    if (static_cast<uint32_t>(millis()) - _last_trans_time > COSNT_TIMEOUT_TRANSFER) {
+        if (_packet_set) {
+            resendPacket();
+        } else {
+            if      (sendQueue(&_ring_buff_alarm)) { ; }
+            else if (sendQueue(&_ring_buff_cmd  )) { ; }
+            else if (sendQueue(&_ring_buff_data )) { ; }
+        }
     }
 }
 
@@ -69,7 +59,7 @@ void CommsControl::receiver() {
             _last_trans_index += Serial.readBytes(_last_trans + _last_trans_index, 1);
 
             // if managed to read at least 1 byte
-            if (_last_trans_index > 0 && _last_trans_index < COMMS_MAX_SIZE_BUFFER) {
+            if (_last_trans_index > 0 && _last_trans_index < COMMS_MAX_SIZE_BUFFER - 1) {
                 current_trans_index = _last_trans_index - 1;
 
                 // find the boundary of frames
@@ -82,7 +72,7 @@ void CommsControl::receiver() {
                         // if managed to decode and compare CRC
                         if (decoder(_last_trans, _start_trans_index, _last_trans_index)) {
 
-                            _sequence_receive = (*(_comms_tmp.getControl()) >> 1 ) & 0x7F;
+                            uint8_t sequence_received = _comms_tmp.getSequenceReceive();
                             // to decide ACK/NACK/other; for other gain sequenceReceive
                             uint8_t control = *(_comms_tmp.getControl() + 1);
                             uint8_t address = *_comms_tmp.getAddress();
@@ -94,27 +84,31 @@ void CommsControl::receiver() {
                             switch(control & COMMS_CONTROL_TYPES) {
                                 case COMMS_CONTROL_NACK:
                                     // received NACK
-                                    // TODO: modify timeout for next sent frame?
-                                    // resendPacket(&address);
                                     break;
                                 case COMMS_CONTROL_ACK:
                                     // received ACK
-                                    finishPacket(type);
+                                    finishPacket(sequence_received);
                                     break;
                                 default:
-                                    uint8_t sequence_receive = (control >> 1 ) & 0x7F;
-                                    sequence_receive += 1;
                                     // received INFORMATION
-                                    if (receivePacket(type)) {
-                                        _comms_ack.setAddress(&address);
-                                        _comms_ack.setSequenceReceive(sequence_receive);
-                                        sendPacket(_comms_ack);
+                                    uint8_t sequence = _comms_tmp.getSequenceSend();
+                                    CommsFormat * response = &_comms_ack;
+
+                                    // check counters
+                                    if (_sequence_receive != sequence) {
+                                        trackMismatch(sequence);
+                                        response = &_comms_nck;
                                     } else {
-                                        _comms_nck.setAddress(&address);
-                                        _comms_nck.setSequenceReceive(sequence_receive);
-                                        sendPacket(_comms_nck);
+                                        resetReceiver(sequence + 1);
                                     }
 
+                                    // check proper unpacking
+                                    if(!receivePacket(type)) {
+                                        response = &_comms_nck;
+                                    }
+                                    response->setAddress(&address);
+                                    response->setSequenceReceive(_sequence_receive);
+                                    sendPacket(*response);
                                     break;
                             }
                         }
@@ -128,7 +122,7 @@ void CommsControl::receiver() {
                         break;
                     }
                 }
-            } else if (_last_trans_index >= COMMS_MAX_SIZE_BUFFER) {
+            } else if (_last_trans_index >= COMMS_MAX_SIZE_BUFFER - 1) {
                 _last_trans_index = 0;
             }
         }
@@ -215,32 +209,55 @@ bool CommsControl::decoder(uint8_t* data, uint8_t data_start, uint8_t data_stop)
 }
 
 // sending anything of commsDATA format
-void CommsControl::sendQueue(RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> *queue) {
+bool CommsControl::sendQueue(RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> *queue) {
     // if have data to send
     if (!queue->isEmpty()) {
-        queue->operator [](0).setSequenceSend(_sequence_send);
-        sendPacket(queue->operator [](0));
-
-        // reset sending counter
-        _last_trans_time = static_cast<uint32_t>(millis());
+        _packet_set = queue->pop(_packet);
+        if (_packet_set) {
+            _packet.setSequenceSend(_sequence_send);
+            sendPacket(_packet);
+        }
     }
+    return _packet_set;
 }
 
 void CommsControl::sendPacket(CommsFormat &packet) {
+    // reset sending counter
+    _last_trans_time = static_cast<uint32_t>(millis());
+
     // if encoded and able to write data
     if (encoder(packet.getData(), packet.getSize()) ) {
         if (Serial.availableForWrite() >= _comms_send_size) {
             Serial.write(_comms_send, _comms_send_size);
-        } 
+        }
     }
 }
 
-// resending the packet, can lower the timeout since either NACK or wrong FCS already checked
-//WIP
-void CommsControl::resendPacket(RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> *queue) {
-    ;
+void CommsControl::resendPacket() {  // retransmit the stored in-flight packet (_packet), or give up on it
+    if ((++_packet_retries) < CONST_PACKET_RETRIES) {  // counter is incremented before the comparison
+        sendPacket(_packet);
+    } else {
+        resetPacket();  // retry budget used up: clear _packet_set and zero the retry counter
+    }
+}
+
+void CommsControl::trackMismatch(uint8_t sequence_receive) {  // record one more sequence-number mismatch
+    if (_mismatch_counter++ > CONST_MISMATCH_COUNTER) {  // post-increment: the value before this call is compared
+        resetReceiver(sequence_receive);  // too many mismatches: zero the counter and adopt the given sequence number
+    }
+}
+
+void CommsControl::resetReceiver(uint8_t sequence_receive) {  // zero the mismatch counter and optionally set the expected sequence
+    _mismatch_counter = 0;
+    if (sequence_receive != 0xFF) {  // 0xFF (the default argument in CommsControl.h) leaves _sequence_receive untouched
+        _sequence_receive = (sequence_receive & 0x7F);  // keep only the low 7 bits
+    }
 }
 
+void CommsControl::resetPacket() {  // forget the in-flight packet so sender() pulls the next one from a queue
+    _packet_set = false;
+    _packet_retries = 0;
+}
 
 // receiving anything of commsFormat
 bool CommsControl::receivePacket(PRIORITY &type) {
@@ -258,25 +275,19 @@ bool CommsControl::receivePacket(PRIORITY &type) {
 }
 
 // if FCS is ok, remove from queue
-void CommsControl::finishPacket(PRIORITY &type) {
-    RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> *queue = getQueue(type);
-
-    if (queue != nullptr && !queue->isEmpty()) {
-        // get the sequence send from first entry in the queue, add one as that should be return
-        // 0x7F to deal with possible overflows (0 should follow after 127)
-        if (((queue->operator [](0).getSequenceSend() + 1) & 0x7F) ==  _sequence_receive) {
-            _sequence_send = (_sequence_send + 1) % 128;
-            CommsFormat comms_rm;
-            if (queue->pop(comms_rm)) {
-                ;
-            }
-        }
+void CommsControl::finishPacket(uint8_t &sequence_received) {
+    // get the sequence send from first entry in the queue, add one as that should be return
+    // 0x7F to deal with possible overflows (0 should follow after 127)
+    uint8_t sequence = ((_sequence_send + 1) & 0x7F);
+    if (sequence == sequence_received) {
+        _sequence_send = sequence;
+        resetPacket();
     }
 }
 
 PRIORITY CommsControl::getInfoType(uint8_t &address) {
     // return enum element corresponding to the address
-    return (PRIORITY)(address & PACKET_TYPE);
+    return static_cast<PRIORITY>(address & PACKET_TYPE);
 }
 
 // get link to queue according to packet format
diff --git a/arduino/common/lib/CommsControl/CommsControl.h b/arduino/common/lib/CommsControl/CommsControl.h
index b623ebb2..39b0875d 100644
--- a/arduino/common/lib/CommsControl/CommsControl.h
+++ b/arduino/common/lib/CommsControl/CommsControl.h
@@ -29,10 +29,13 @@ private:
     RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> *getQueue(PRIORITY &type);
     PRIORITY getInfoType(uint8_t &address);
 
-    void sendQueue    (RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> *queue);
-    void resendPacket (RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> *queue);
+    bool sendQueue(RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> *queue);
+    void resendPacket ();
+    void resetPacket  ();
+    void trackMismatch(uint8_t sequence_receive);
+    void resetReceiver(uint8_t sequence_receive = 0xFF);
     bool receivePacket(PRIORITY &type);
-    void finishPacket (PRIORITY &type);
+    void finishPacket (uint8_t &sequence_received);
 
     bool encoder(uint8_t* payload, uint8_t data_size);
     bool decoder(uint8_t* payload, uint8_t dataStart, uint8_t data_stop);
@@ -46,6 +49,10 @@ private:
     CommsFormat _comms_ack;
     CommsFormat _comms_nck;
 
+    uint8_t     _mismatch_counter = 0;  // sequence mismatches counted since the last resetReceiver()
+    CommsFormat _packet;                // packet popped from a send queue and not yet finished (ACKed or dropped)
+    uint8_t     _packet_retries = 0;    // resend attempts made so far for _packet
+    bool        _packet_set = false;    // true while _packet holds an unfinished packet
     RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> _ring_buff_alarm;
     RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> _ring_buff_data;
     RingBuf<CommsFormat, COMMS_MAX_SIZE_RB_SENDING> _ring_buff_cmd;
diff --git a/raspberry-dataserver/CommsFormat.py b/raspberry-dataserver/CommsFormat.py
index 56b9d726..0bf02968 100644
--- a/raspberry-dataserver/CommsFormat.py
+++ b/raspberry-dataserver/CommsFormat.py
@@ -143,6 +143,7 @@ class CommsPacket(object):
     def __init__(self, rawdata):
         self._data = rawdata
         self._sequence_receive = 0
+        self._sequence_send = 0  # backs the sequence_send property; overwritten when a data frame is decoded
         self._address = None
         self._byteArray = None
         self._datavalid = False
@@ -155,6 +156,10 @@ class CommsPacket(object):
     @property
     def sequence_receive(self):
         return self._sequence_receive
+    
+    @property
+    def sequence_send(self):  # send-sequence number read from a received data frame (not assigned for ACK/NACK in the visible code)
+        return self._sequence_send
 
     @property
     def address(self):
@@ -186,7 +191,7 @@ class CommsPacket(object):
         tmp_comms.copyBytes(byteArray)
         if tmp_comms.compareCrc():
             control     = tmp_comms.getData()[tmp_comms.getControl()+1]
-            self._sequence_receive = (tmp_comms.getData()[tmp_comms.getControl()] >> 1) & 0x7F
+            self._sequence_receive = tmp_comms.getSequenceReceive()
             self._address = tmp_comms.getData()[1]
             
             # get type of packet
@@ -199,7 +204,7 @@ class CommsPacket(object):
                 self._acked = True
             else:
                 # received data
-                self._sequence_receive = ((control >> 1) & 0x7F) + 1
+                self._sequence_send = tmp_comms.getSequenceSend()
                 self._byteArray = tmp_comms.getData()[tmp_comms.getInformation():tmp_comms.getFcs()]
         else:
             raise CommsChecksumError
diff --git a/raspberry-dataserver/CommsLLI.py b/raspberry-dataserver/CommsLLI.py
index ce7e3f6a..937d21c7 100755
--- a/raspberry-dataserver/CommsLLI.py
+++ b/raspberry-dataserver/CommsLLI.py
@@ -17,6 +17,9 @@ from CommsFormat import CommsPacket, CommsACK, CommsNACK, CommsChecksumError, ge
 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
 logging.getLogger().setLevel(logging.INFO)
 
+class SequenceReceiveMismatch(Exception):  # raised by the CommsLLI.sequence_receive setter when the offered sequence number is not the expected one
+    pass
+
 class CommsLLI:
     def __init__(self, loop, throttle=1, file='', number=10000):
         super().__init__()
@@ -46,7 +49,7 @@ class CommsLLI:
         self._dv_data     = asyncio.Event(loop=self._loop)
         # maps between address and queues/events/timeouts 
         self._queues   = {0xC0: self._alarms, 0x80: self._commands, 0x40: self._data}
-        self._timeouts = {0xC0: 10, 0x80: 50, 0x40: 200}
+        self._timeouts = {0xC0: 5, 0x80: 5, 0x40: 5}
         self._acklist  = {0xC0: self._dv_alarms, 0x80: self._dv_commands, 0x40: self._dv_data}
         
         # receive
@@ -58,7 +61,7 @@ class CommsLLI:
 
         # packet counting
         self._sequence_send    = 0
-        self._sequence_receive = 0
+        self.resetReceiver(0)
 
     async def main(self, device, baudrate):
         try:
@@ -79,11 +82,12 @@ class CommsLLI:
             logging.debug("Waiting for Command")
             packet = await queue.get()
             packet.setSequenceSend(self._sequence_send)
-            for send_attempt in range(5):
-                # try to send the packet 5 times
+            for send_attempt in range(10):
+                # try to send the packet 10 times
                 try:
                     await self.sendPacket(packet)
                     await asyncio.wait_for(self._acklist[address].wait(), timeout=self._timeouts[address] / 1000)
+                    self._acklist[address].clear()
                 except asyncio.TimeoutError:
                     pass
                 except Exception:
@@ -140,9 +144,11 @@ class CommsLLI:
     async def sendPacket(self, packet):
         if isinstance(packet, CommsACK):
             # don't log acks
+            logging.debug(f"Sending ACK: {binascii.hexlify(packet.encode())}")
             pass
         elif isinstance(packet, CommsNACK):
-            logging.warning(f"Sending NACK: {binascii.hexlify(packet.encode())}")
+#             logging.debug(f"Sending NACK: {binascii.hexlify(packet.encode())}")
+            pass
         else:
             logging.info(f"Sending {binascii.hexlify(packet.encode())}")
         self._writer.write(packet.encode())
@@ -157,20 +163,33 @@ class CommsLLI:
                 break
         return CommsPacket(bytearray(rawdata))
 
-    def finishPacket(self, address):
-        self._sequence_send = (self._sequence_send + 1) % 128
+    def finishPacket(self, address, sequence_received):
         try:
-            self._queues[address].task_done()
+            sequence = (self._sequence_send + 1) % 128
+            if sequence == sequence_received:
+                self._queues[address].task_done()
+                self._sequence_send = sequence
         except ValueError:
             # task has already been purged from queue
             pass
         else:
             self._acklist[address].set()
+            
+    def resetReceiver(self, sequence_receive = -1):  # zero the mismatch counter; optionally overwrite the expected sequence
+        self._mismatch_counter = 0
+        if sequence_receive >= 0:  # the default -1 leaves _sequence_receive unchanged
+            self._sequence_receive = sequence_receive
+            
+    def trackMismatch(self, sequence_receive):  # count a sequence mismatch; resynchronise once the counter is above 20
+        if self._mismatch_counter > 20 :
+            self.resetReceiver(sequence_receive)  # zeroes the counter and adopts sequence_receive as the expected value
+            logging.warning(f"Received more than 20 sequence_receive mismatches, resetting")
+        else:
+            self._mismatch_counter += 1
 
     async def recv(self):
         while self._connected:
             packet = await self.readPacket()
-
             try:
                 data = packet.decode()
             except CommsChecksumError:
@@ -183,28 +202,46 @@ class CommsLLI:
                 if packet.acked:
                     logging.info("Received ACK")
                     # increase packet counter
-                    self.finishPacket(packet.address)
+                    self.finishPacket(packet.address, packet.sequence_receive)
                 else:
-                    logging.debug("Received NACK")
+                    logging.info("Received NACK")
             else:
                 # packet should contain valid data
                 try:
                     payload = PayloadFormat.fromByteArray(packet.byteArray)
+                    
+                    self.sequence_receive = packet.sequence_send
+                    self.resetReceiver()
                     self.payloadrecv = payload
-                    logging.debug(f"Received payload type {payload.getType()} for timestamp {payload.timestamp}")
+#                     logging.debug(f"Received payload type {payload.getType()} for timestamp {payload.timestamp}")
                     comms_response = CommsACK(packet.address)
                 except (StructError, ValueError):
                     # invalid payload, but valid checksum - this is bad!
                     logging.error(f"Invalid payload: {payload}")
                     # restart/reflash/swap to redundant microcontroller?
                     comms_response = CommsNACK(packet.address)
+                except SequenceReceiveMismatch:
+#                     logging.debug(f"Mismatch sequence receive, expected: {self.sequence_receive}; received: {packet.sequence_send}")
+                    comms_response = CommsNACK(packet.address)
+                    self.trackMismatch(packet.sequence_send)
                 except HEVVersionError as e:
                     logging.critical(f"HEVVersionError: {e}")
                     exit(1)
                 finally:
-                    comms_response.setSequenceReceive(packet.sequence_receive)
+                    comms_response.setSequenceReceive(self.sequence_receive)
                     await self.sendPacket(comms_response)
 
+    @property
+    def sequence_receive(self):  # sequence number the next incoming data packet is checked against
+        return self._sequence_receive
+    
+    @sequence_receive.setter
+    def sequence_receive(self, sequence):  # not a plain assignment: validates `sequence`, then advances the counter
+        if self._sequence_receive == sequence:
+            self._sequence_receive = (self._sequence_receive + 1) % 128  # wraps to 0 after 127
+        else:
+            raise SequenceReceiveMismatch  # caught in recv(), which answers with a NACK
+                    
     # callback to dependants to read the received payload
     @property
     def payloadrecv(self):
-- 
GitLab