#
# Copyright (c) 2021 Contributors to the Eclipse Foundation
#
# This program and the accompanying materials are made
# available under the terms of the Eclipse Public License 2.0
# which is available at https://www.eclipse.org/legal/epl-2.0/
#
# SPDX-License-Identifier: EPL-2.0
#

import threading, random, asyncio, queue, ctypes, datetime, subprocess, os, io, json, traceback, sys, re, time, argparse
from typing import Dict, Callable, List, Optional, Union, Type
from snakes.nets import PetriNet
if not 'SELF_CONTAINED' in globals():
    from model import Event, EventType, Parameters, Constraint
    from model import Event, Constraint
    from walker import Walker
    nets: Dict[str, Callable[[], PetriNet]]
    constraints: List[Type['Constraint']]


class TestApplication:
    """Run a test session against one or more adapter processes.

    The application spawns the adapter processes described by the global
    ``ADAPTERS`` mapping, writes events produced by the walker to the
    adapters' stdin as JSON lines, and hands JSON lines read from the
    adapters' stdout back to the walker. After a run it can write state,
    event and transition coverage reports next to this script.
    """

    # Names of the coverage reports written into this script's directory.
    _COVERAGE_FILES = ("state_coverage.txt", "event_coverage.txt", "transition_coverage.txt")

    # When True, log() prints its messages to stdout.
    print_log = False
    # Name of the test algorithm handed to the walker.
    algorithm = "Random"
    # When not empty, the recorder is told to save to this file on start.
    save_file = ""
    # True between a successful start and the end of stop().
    running = False
    # True while a start or stop is in progress.
    stopping_or_starting = False
    # Set by rerun() when a recording is being replayed.
    running_a_rerun = False
    start_time: datetime.datetime
    # Adapter processes by key; only assigned once start_adapter() has run.
    adapters: Dict[str, subprocess.Popen]
    # Maps an event's interface to an adapter key; assigned by start_adapter().
    interfaces: dict
    walker: 'TestApplicationWalker'
    recorder = WalkRecorder()
    debugger = Debugger()

    @staticmethod
    def _coverage_path(filename: str) -> str:
        """Return the path of ``filename`` inside this script's directory.

        The script path is made absolute first: ``os.path.dirname(__file__)``
        is empty when ``__file__`` has no directory part, and appending
        ``"/name"`` to it would point at the filesystem root.
        """
        return os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)

    def start_adapter(self):
        """Spawn every configured adapter and start its stdout/stderr readers.

        ``ADAPTERS`` and ``INTERFACES`` are globals holding dict literals as
        text; single quotes are swapped for double quotes so that they can be
        parsed as JSON.
        """
        def reader(pipe: io.BytesIO, cb: Callable[[str], None]):
            # Forward each line (without its newline) to cb until the pipe
            # reaches end-of-file, then report that the adapter is gone.
            with pipe:
                for line in iter(pipe.readline, b''):
                    text = line.decode()
                    # Only strip an actual newline: the last line of a stream
                    # may be unterminated and must not lose a character.
                    if text.endswith("\n"):
                        text = text[:-1]
                    cb(text)
            self.stop("Adapter stopped")

        self.adapters = dict()
        adapters_map = json.loads(ADAPTERS.replace("'", "\""))
        self.interfaces = json.loads(INTERFACES.replace("'", "\""))

        for key, command in adapters_map.items():
            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
            self.adapters[key] = process
            threading.Thread(target=reader, args=[process.stdout, self.on_stdout]).start()
            threading.Thread(target=reader, args=[process.stderr, self.on_stderr]).start()

    def log(self, line: str):
        """Print ``line`` with a ``LOG:`` prefix when logging is enabled."""
        if self.print_log:
            print(f"LOG: {line}")

    def send_to_adapter(self, event: 'Event'):
        """Write ``event`` as one JSON line to the adapter serving its interface.

        Does nothing while the application is not running.
        """
        if self.running:
            key = self.interfaces[event.interface]
            self.log(f"-> {str(event)}")
            stdin = self.adapters[key].stdin
            assert stdin is not None
            stdin.write((json.dumps(event.to_json()) + "\n").encode())
            stdin.flush()

    def on_stdout(self, line: str):
        """Handle one line of adapter stdout.

        A message of kind ``Adapter`` and type ``started`` starts the walker.
        Any other message that is not of kind ``Adapter`` is, while running,
        parsed as an event and passed to the walker. Every failure stops the
        application and prints the traceback.
        """
        try:
            jsn = json.loads(line)
            if jsn['kind'] == 'Adapter':
                if jsn['type'] == 'started':
                    self.log("Adapter started")
                    self.walker.start()
            elif self.running:
                event = Event.from_json(jsn)
                self.log(f"<- {str(event)}")
                self.walker.received_event(event)
        except Exception as e:
            self.stop(f"Error while processing adapter stdout: {repr(e)}")
            traceback.print_exc()

    def on_stderr(self, line: str):
        """Log one line of adapter stderr."""
        self.log(f"Adapter: {line}")

    def stop(self, reason: Optional[str], force: bool = False):
        """Stop the walker, kill all adapters and stop the recorder.

        The call is ignored when the application is not running, or when a
        start/stop is already in progress and ``force`` is False.

        Args:
            reason: Text appended to the "Stopping" log line, or None.
            force: Proceed even though a start/stop is in progress.
        """
        if not self.running:
            return
        if self.stopping_or_starting and not force:
            return
        self.stopping_or_starting = True
        suffix = f": {reason}" if reason is not None else ""
        self.log(f"Stopping{suffix}")
        self.walker.stop()
        for process in self.adapters.values():
            process.kill()
        self.recorder.stop()
        self.stopping_or_starting = False
        self.running = False
        self.running_a_rerun = False
        self.log("Stopped")

    def start_stop(self):
        """Start the application when stopped, stop it when running.

        The work is done on a daemon thread, so this method returns before
        the state change is complete. The call is ignored while another
        start/stop is in progress.
        """
        if self.stopping_or_starting: return
        self.stopping_or_starting = True

        if self.save_file != "":
            self.recorder.saveAs(self.save_file)

        def clear_coverage_files():
            # Remove reports of a previous run; a missing file is fine.
            for filename in self._COVERAGE_FILES:
                try:
                    os.remove(self._coverage_path(filename))
                except FileNotFoundError:
                    pass

        def handle():
            if not self.running:
                self.log('Starting...')
                self.start_time = datetime.datetime.now()
                self.recorder.log(self.log)
                self.walker = TestApplicationWalker(nets, constraints, self.send_to_adapter, self.stop, self.algorithm, self.log, self.recorder, self.debugger)
                try:
                    self.start_adapter()
                except Exception as e:
                    self.log(f"Failed to start adapter: '{str(e)}'")
                    self.stopping_or_starting = False
                    return
                self.running = True
                self.log('Started')
                clear_coverage_files()
            else:
                self.stop(None, True)
            self.stopping_or_starting = False

        hdl = threading.Thread(target=handle)
        hdl.daemon = True
        hdl.start()

    def save_coverage(self):
        """Write the state, event and transition coverage reports.

        Does nothing when no walker has been created yet. Each report is
        written through a context manager so that every file is flushed and
        closed, also when writing fails half-way.
        """
        if not hasattr(self, 'walker'): return
        walker = self.walker.walker
        self._write_coverage_file("state_coverage.txt", "states", "Component,Port,Interface,State\n",
                                  walker.all_states, walker.seen_states, 4)
        self._write_coverage_file("event_coverage.txt", "events", "Component,Port,Interface,Event\n",
                                  walker.all_events, walker.seen_events, 4)
        self._write_coverage_file("transition_coverage.txt", "transition clauses", "Component,Port,Interface,Clause,SourceLine\n",
                                  walker.all_clauses, walker.seen_clauses, 5)

    def _write_coverage_file(self, filename, label, header, all_items, seen_items, nr_columns):
        """Write one report with an "Uncovered" and a "Covered" section.

        A section is left out when it has no entries; the file itself is
        always (re)created.
        """
        unseen_items = all_items.difference(seen_items)
        with open(self._coverage_path(filename), "w") as f:
            if len(unseen_items) > 0:
                f.write(f"Uncovered {label}:\n\n")
                f.write(header)
                self.print_coverage_info(f, unseen_items, nr_columns)
                f.write("\n")
            if len(seen_items) > 0:
                f.write(f"Covered {label}:\n\n")
                f.write(header)
                self.print_coverage_info(f, seen_items, nr_columns)

    def print_coverage_info(self, f, info, nrColumns):
        """Write the entries of ``info`` to ``f`` as sorted CSV rows.

        Every entry is a dot-separated string. Its first three fragments are
        written in reverse order as the first three columns. With
        ``nrColumns == 4`` the remaining fragments form one column; otherwise
        the last fragment becomes a column of its own.
        """
        ports = dict()
        for s in info:
            fragments = s.split(".")
            port = fragments[2] + "," + fragments[1] + "," + fragments[0]
            if nrColumns == 4:
                content = ".".join(fragments[3:])
            else:
                content = ".".join(fragments[3:-1]) + "," + fragments[-1]
            ports.setdefault(port, []).append(content)
        for port in sorted(ports):
            for content in sorted(ports[port]):
                f.write(port + "," + content + "\n")


def timeout_condition(app, timeout, actual_timeout):
    """Return whether ``actual_timeout`` has reached ``timeout``.

    The comparison is printed first when ``app.print_log`` is set.
    """
    reached = actual_timeout >= timeout
    if app.print_log:
        print(f"actual_timeout: {actual_timeout} >= timeout: {timeout} result: {reached}")
    return reached

def state_coverage_condition(app, stateCoverage, actual_stateCoverage):
    """Return whether ``actual_stateCoverage`` has reached ``stateCoverage``.

    The comparison is printed first when ``app.print_log`` is set.
    """
    reached = actual_stateCoverage >= stateCoverage
    if app.print_log:
        print(f"actual_stateCoverage: {actual_stateCoverage} >= stateCoverage: {stateCoverage} result: {reached}")
    return reached

def event_coverage_condition(app, eventCoverage, actual_eventCoverage):
    """Return whether ``actual_eventCoverage`` has reached ``eventCoverage``.

    The comparison is printed first when ``app.print_log`` is set.
    """
    reached = actual_eventCoverage >= eventCoverage
    if app.print_log:
        print(f"actual_eventCoverage: {actual_eventCoverage} >= eventCoverage: {eventCoverage} result: {reached}")
    return reached

def transition_coverage_condition(app, transitionCoverage, actual_transitionCoverage):
    """Return whether ``actual_transitionCoverage`` has reached ``transitionCoverage``.

    The comparison is printed first when ``app.print_log`` is set.
    """
    reached = actual_transitionCoverage >= transitionCoverage
    if app.print_log:
        print(f"actual_transitionCoverage: {actual_transitionCoverage} >= transitionCoverage: {transitionCoverage} result: {reached}")
    return reached

def validate_condition_string(condition):
    """Check that ``condition`` is a well-formed stop condition.

    A well-formed condition is one or more terms joined by ``and``/``or``,
    where a term is ``timeout(N)``, ``stateCoverage(N)``, ``eventCoverage(N)``
    or ``transitionCoverage(N)`` with ``N`` a sequence of digits.

    Returns:
        The ``re.Match`` covering the whole string, or ``None`` when the
        string is not well-formed.
    """
    function_names = ("timeout", "stateCoverage", "eventCoverage", "transitionCoverage")
    # One alternative per function: name, optional blanks, "(digits)".
    term = "|".join(rf'{name}\s*\(\s*\d+\s*\)' for name in function_names)
    grammar = rf'({term})(\s+(and|or)\s+({term}))*'
    return re.fullmatch(grammar, condition)

def evaluate_condition(left, operator, right):
    """Combine ``left`` and ``right`` with the named boolean operator.

    Args:
        left: Left operand.
        operator: Either ``"and"`` or ``"or"``.
        right: Right operand.

    Returns:
        The result of Python's ``and``/``or`` on the operands, i.e. one of
        the two operands.

    Raises:
        ValueError: If ``operator`` is neither ``"and"`` nor ``"or"``.
    """
    if operator == "and":
        return left and right
    if operator == "or":
        return left or right
    raise ValueError(f"Unknown operator: {operator}")

def stop_condition(condition, app, timeoutValue):
    """Tell whether the run must continue because ``condition`` is not yet met.

    ``condition`` is split on ``and``/``or`` and its terms are combined
    strictly from left to right. ``timeout(N)`` is compared with
    ``timeoutValue``; ``stateCoverage(N)``, ``eventCoverage(N)`` and
    ``transitionCoverage(N)`` are compared with the rounded percentage of
    seen items on ``app.walker.walker``.

    Args:
        condition: The stop-condition expression.
        app: Object providing ``print_log`` and, for coverage terms,
            ``walker.walker`` with the ``seen_*``/``all_*`` collections.
        timeoutValue: The value compared against ``timeout(N)`` terms.

    Returns:
        The negation of the evaluated condition: truthy while the stop
        condition has NOT been reached.

    Raises:
        ValueError: If a term is not one of the four known functions.
    """
    def percentage(seen, total):
        # With nothing to cover there is nothing left uncovered: report 100%
        # rather than failing with a ZeroDivisionError.
        if len(total) == 0:
            return 100
        return round(len(seen) * 100 / len(total))

    def state_check(value):
        model = app.walker.walker
        return state_coverage_condition(app, value, percentage(model.seen_states, model.all_states))

    def event_check(value):
        model = app.walker.walker
        return event_coverage_condition(app, value, percentage(model.seen_events, model.all_events))

    def transition_check(value):
        model = app.walker.walker
        return transition_coverage_condition(app, value, percentage(model.seen_clauses, model.all_clauses))

    # Each term pattern is paired with the check applied to its integer argument.
    checks = (
        (re.compile(r'timeout\s*\(\s*(\d+)\s*\)'), lambda value: timeout_condition(app, value, timeoutValue)),
        (re.compile(r'stateCoverage\s*\(\s*(\d+)\s*\)'), state_check),
        (re.compile(r'eventCoverage\s*\(\s*(\d+)\s*\)'), event_check),
        (re.compile(r'transitionCoverage\s*\(\s*(\d+)\s*\)'), transition_check),
    )

    result = None
    operator = None
    # The capture group keeps the separators, so terms and operators alternate.
    for raw_token in re.split(r'(\s+and\s+|\s+or\s+)', f'{condition}'):
        token = raw_token.strip()
        if token in ("and", "or"):
            operator = token
            continue

        for pattern, check in checks:
            found = pattern.match(token)
            if found:
                condition_result = check(int(found.group(1)))
                break
        else:
            raise ValueError(f"Unknown condition: {token}")

        if result is None:
            result = condition_result
        else:
            result = evaluate_condition(result, operator, condition_result)
    if app.print_log:
        print(f"result: {result}")
    return not result

def _require_recording_file(path):
    """Exit with status 1 unless ``path`` is an existing ``.recording`` file."""
    if not path.endswith(".recording"):
        print("file does not have a recording extension")
        sys.exit(1)
    if not os.path.isfile(path):
        print("file does not exist")
        sys.exit(1)


def rerun(app, args):
    """Load the recording selected on the command line, if any.

    For each of ``args.run``, ``args.continue_run`` and
    ``args.continue_and_save_run`` that is not empty, the file is validated,
    handed to ``app.recorder.playFrom`` and ``app.running_a_rerun`` is set.
    An invalid file prints a message and exits the program with status 1.

    Args:
        app: Object providing ``recorder.playFrom`` and ``running_a_rerun``.
        args: Parsed command-line arguments.
    """
    # (file, second and third positional argument passed on to playFrom)
    modes = (
        (args.run, False, False),
        (args.continue_run, False, True),
        (args.continue_and_save_run, True, True),
    )
    for recording, second_flag, third_flag in modes:
        if recording == "":
            continue
        _require_recording_file(recording)
        app.recorder.playFrom(recording, second_flag, third_flag)
        app.running_a_rerun = True

class AllowNewLinesHelpFormatter(argparse.HelpFormatter):
    """Help formatter that breaks help text only at its own line breaks."""

    def _split_lines(self, text, width):
        # ``width`` is ignored on purpose: lines are returned as written
        # instead of being re-wrapped.
        lines = text.splitlines()
        return lines

def stop_condition_reached(walker):
    """Evaluate the global stop condition against ``walker``.

    The expression is read from the global ``stop_condition_arg``; its terms
    are combined strictly from left to right. ``timeout(N)`` is compared with
    the global ``timeout`` counter, the coverage terms with the rounded
    percentage of seen items on ``walker``.

    Args:
        walker: Object providing the ``seen_*``/``all_*`` collections.

    Returns:
        The evaluated condition: truthy once the stop condition is reached.

    Raises:
        ValueError: If a term is not one of the four known functions.
    """
    def percentage(seen, total):
        # With nothing to cover there is nothing left uncovered: report 100%
        # rather than failing with a ZeroDivisionError.
        if len(total) == 0:
            return 100
        return round(len(seen) * 100 / len(total))

    # Each term pattern is paired with a function returning the current value.
    terms = (
        (re.compile(r'timeout\s*\(\s*(\d+)\s*\)'),
         lambda: timeout),
        (re.compile(r'stateCoverage\s*\(\s*(\d+)\s*\)'),
         lambda: percentage(walker.seen_states, walker.all_states)),
        (re.compile(r'eventCoverage\s*\(\s*(\d+)\s*\)'),
         lambda: percentage(walker.seen_events, walker.all_events)),
        (re.compile(r'transitionCoverage\s*\(\s*(\d+)\s*\)'),
         lambda: percentage(walker.seen_clauses, walker.all_clauses)),
    )

    result = None
    operator = None
    # The capture group keeps the separators, so terms and operators alternate.
    for raw_token in re.split(r'(\s+and\s+|\s+or\s+)', f'{stop_condition_arg}'):
        token = raw_token.strip()
        if token in ("and", "or"):
            operator = token
            continue

        for pattern, current_value in terms:
            found = pattern.match(token)
            if found:
                condition_result = current_value() >= int(found.group(1))
                break
        else:
            raise ValueError(f"Unknown condition: {token}")

        if result is None:
            result = condition_result
        else:
            result = evaluate_condition(result, operator, condition_result)
    return result

# Stop-condition expression from the command line; read by stop_condition_reached().
stop_condition_arg = None
# Number of one-second waits done by the main loop; read by stop_condition_reached().
timeout = 0

if __name__ == "__main__":
    stop_condition_help = (
        "Provide minimal values for:\n"
        "    timeout(<value>), where value is a integer in seconds;\n"
        "    stateCoverage(<value>), where value is a integer in percentage;\n"
        "    eventCoverage(<value>), where value is a integer in percentage; and/or\n"
        "    transitionCoverage(<value>), where value is a integer in percentage.\n"
        "With these functions expressions can be made using \"and\" and \"or\".\n"
        "(example: -sc \"timeout(10) and eventCoverage(50)\")\n"
        "The expressions are left-associative. So\n"
        "    \"timeout(<value>) and stateCoverage(<value>) or eventCoverage(<value>)\"\n"
        "should be read as\n"
        "    \"(timeout(<value>) and stateCoverage(<value>)) or eventCoverage(<value>)\".\n"
        "(default: %(default)s)"
    )

    parser = argparse.ArgumentParser(formatter_class=AllowNewLinesHelpFormatter)
    parser.add_argument("-l", "--log", action="store_true", help="Log to console")
    parser.add_argument("-c", "--coverage", action="store_true", help="Save coverage information")
    parser.add_argument('-a', "--algorithm", default="random", choices=["random", "improved", "joker", "scenario-random", "scenario-joker"], help="Choose test algorithm (default: %(default)s)")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-s", "--save", type=str, default="", help="Save recording file")
    group.add_argument("-r", "--run", type=str, default="", help="Recorded run")
    group.add_argument("-cr", "--continue_run", type=str, default="", help="Continue recorded run")
    group.add_argument("-csr", "--continue_and_save_run", type=str, default="", help="Continue and save recorded run")
    parser.add_argument("-sc", "--stop_condition", type=str, default="timeout(10)", help=stop_condition_help)

    args = parser.parse_args()
    stop_condition_arg = args.stop_condition
    if not validate_condition_string(args.stop_condition):
        print("Error parsing stop condition")
        sys.exit(1)

    app = TestApplication()
    app.print_log = args.log

    # "random" is not listed: it keeps the application's default algorithm.
    algorithm_names = {
        "improved": "Prioritize non-selected",
        "joker": "Joker",
        "scenario-random": "Scenario Random",
        "scenario-joker": "Scenario Joker",
    }
    if args.algorithm in algorithm_names:
        app.algorithm = algorithm_names[args.algorithm]

    if args.save != "":
        app.save_file = args.save if args.save.endswith(".recording") else args.save + ".recording"

    rerun(app, args)
    app.start_stop()

    # start_stop() raises stopping_or_starting before it returns and the
    # start-up thread lowers it when it is done, whether it succeeded or not.
    # Polling both flags (with a short sleep instead of a busy loop) therefore
    # ends the wait on success as well as on a failed start.
    print("wait for running ")
    while not app.running and app.stopping_or_starting:
        print(".", end="", flush=True)
        time.sleep(0.1)
    print(".")
    if not app.running:
        print("App failed to start")
        # Best effort: adapters that did start would otherwise keep their
        # reader threads, and thereby this process, alive.
        for process in (getattr(app, "adapters", None) or {}).values():
            process.kill()
        sys.exit(1)
    print("App is running")

    app.walker.walker.check_stop_condition = stop_condition_reached
    # stop_condition() is truthy while the condition has NOT been reached.
    while stop_condition(args.stop_condition, app, timeout) and app.running:
        time.sleep(1)
        timeout = timeout + 1
    app.stop(None, True)

    print("Time: " + str(timeout))
    print("Test app steps: " + str(app.walker.walker.test_app_steps))
    print("SUT steps: " + str(app.walker.walker.sut_steps))
    print("SUT events timeouts: " + str(app.walker.walker.sut_events_timeout_counter))
    if hasattr(app.walker.strategy_implementation, "executed_scenarios"):
        print("Executed scenarios: " + " ".join(app.walker.strategy_implementation.executed_scenarios))
    if args.coverage:
        app.save_coverage()

    # Ending before the stop condition was reached counts as a failure,
    # except when a recording was run with --run.
    if stop_condition(args.stop_condition, app, timeout) and not app.running and args.run == "":
        print("App stopped (1) = fail")
        sys.exit(1)
    print("App stopped (0) = success")
    sys.exit(0)