import re

from weevely import utils
from weevely.core import messages
from weevely.core.loggers import log
from weevely.core.module import Module
from weevely.core.vectors import PhpCode


class Console(Module):
    """Execute SQL query or run console."""

    def init(self):
        self.register_info({"author": ["Emilio Pinna"], "license": "GPLv3"})

        # Each vector prints rows as columns joined by ${linsep}, each row
        # terminated by ${linsep}${colsep}, and finally ${errsep} followed by
        # the connection/error text. The mysql and mysql_database vectors also
        # print a header row with the field names first; the others do not.
        # NOTE(review): the mysql vectors echo mysqli_error() *before*
        # ${errsep}, so query errors end up in the data section rather than in
        # "error" — confirm whether that is intended.
        # NOTE(review): mysql_fallback calls mysqli_query()/mysqli_close()
        # without a link argument, which the mysqli API requires — verify it
        # works on the targeted PHP versions.
        self.register_vectors(
            [
                PhpCode(
                    """mysqli_report(MYSQLI_REPORT_OFF);if($s=mysqli_connect('${host}:${port}','${user}','${passwd}')){$r=mysqli_query($s,'${query}');if($r){$f=mysqli_fetch_fields($r);foreach($f as $v){echo $v->name.'${linsep}';};echo '${colsep}';while($c=mysqli_fetch_row($r)){echo implode('${linsep}',$c);echo '${linsep}${colsep}';}};echo @mysqli_error($s);@mysqli_close($s);}echo '${errsep}'.@mysqli_connect_error();""",
                    name="mysql",
                ),
                PhpCode(
                    """mysqli_report(MYSQLI_REPORT_OFF);if($s=mysqli_connect('${host}:${port}','${user}','${passwd}','${database}')){$r=mysqli_query($s,'${query}');if($r){$f=mysqli_fetch_fields($r);foreach($f as $v){echo $v->name.'${linsep}';};echo '${colsep}';while($c=mysqli_fetch_row($r)){echo implode('${linsep}',$c);echo '${linsep}${colsep}';}};echo @mysqli_error($s);@mysqli_close($s);}echo '${errsep}'.@mysqli_connect_error();""",
                    name="mysql_database",
                ),
                PhpCode(
                    """mysqli_report(MYSQLI_REPORT_OFF);$r=mysqli_query('${query}');if($r){while($c=mysqli_fetch_row($r)){foreach($c as $key=>$value){echo $value.'${linsep}';}echo '${colsep}';}};mysqli_close();echo '${errsep}'.@mysqli_connect_error().' '.@mysqli_error();""",
                    name="mysql_fallback",
                ),
                PhpCode(
                    """if(pg_connect('host=${host} port=${port} user=${user} password=${passwd}')){$r=pg_query('${query}');if($r){while($c=pg_fetch_row($r)){foreach($c as $key=>$value){echo $value.'${linsep}';}echo '${colsep}';}};pg_close();}echo '${errsep}'.@pg_last_error();""",
                    name="pgsql",
                ),
                PhpCode(
                    """if(pg_connect('host=${host} port=${port} user=${user} dbname=${database} password=${passwd}')){$r=pg_query('${query}');if($r){while($c=pg_fetch_row($r)){foreach($c as $key=>$value){echo $value.'${linsep}';}echo '${colsep}';}};pg_close();}echo '${errsep}'.@pg_last_error();""",
                    name="pgsql_database",
                ),
                PhpCode(
                    """$r=pg_query('${query}');if($r){while($c=pg_fetch_row($r)){foreach($c as $key=>$value){echo $value.'${linsep}';} echo '${colsep}';}};pg_close();echo '${errsep}'.@pg_last_error();""",
                    name="pgsql_fallback",
                ),
            ]
        )

        self.register_arguments(
            [
                {"name": "-user", "help": "SQL username"},
                {"name": "-passwd", "help": "SQL password"},
                {"name": "-host", "help": "Db host (default: localhost)", "nargs": "?", "default": "localhost"},
                {"name": "-dbms", "help": "Db type", "choices": ("mysql", "pgsql"), "default": "mysql"},
                {"name": "-database", "help": "Database name"},
                {"name": "-port", "help": "Port number", "type": int, "default": 0},
                {"name": "-query", "help": "Execute a single query"},
                {"name": "-encoding", "help": "Db text encoding", "default": "utf-8"},
            ]
        )

    # Arguments interpolated inside PHP single-quoted literals by the vectors.
    _QUOTED_ARGS = ("query", "host", "user", "passwd", "database")

    @staticmethod
    def _php_escape(value):
        """Return *value* escaped for use inside a PHP single-quoted string.

        Backslashes are doubled first so the escaping added for quotes is not
        itself escaped again.
        """
        return value.replace("\\", "\\\\").replace("'", "\\'")

    def _query(self, vector, args):
        """Run ``args["query"]`` through the PHP *vector* and parse its output.

        :param vector: name of a vector registered in ``init``
            (e.g. ``"mysql"``, ``"pgsql_database"``, ``"mysql_fallback"``).
        :param args: module arguments (query, credentials, host, port,
            dbms, database, encoding). Only ``args["port"]`` is written back
            (filled with the dbms default when unset); escaping is applied to
            a copy so repeated calls never double-escape.
        :return: ``{"error": str, "result": list[list[str]]}`` — the error
            text after the error separator, and one list of column values per
            output row.
        """
        # Randomly generate separators so they are unlikely to appear in data
        colsep = "----%s" % utils.strings.randstr(6).decode("utf-8")
        linsep = "----%s" % utils.strings.randstr(6).decode("utf-8")
        errsep = "----%s" % utils.strings.randstr(6).decode("utf-8")

        # Set default port depending on selected dbms
        if args["port"] <= 0:
            args["port"] = 5432 if args["dbms"] == "pgsql" else 3306

        vector_args = dict(args, colsep=colsep, linsep=linsep, errsep=errsep)
        for key in self._QUOTED_ARGS:
            if isinstance(vector_args.get(key), str):
                vector_args[key] = self._php_escape(vector_args[key])

        result = self.vectors.get_result(vector, vector_args)

        # We want text; raw bytes are decoded with the user-selected encoding
        # (str() on bytes would only give their repr).
        if isinstance(result, bytes):
            result = result.decode(args.get("encoding") or "utf-8", "replace")
        else:
            result = str(result)

        # If there is no errsep, something went really wrong (e.g. functions not callable)
        if errsep not in result:
            return {"error": messages.module_sql_console.unexpected_response, "result": []}

        # The payload echoes errsep last, so split on its final occurrence
        data, _, error = result.rpartition(errsep)

        # Rows end with linsep+colsep; columns inside a row are split by linsep
        rows = data.replace(linsep + colsep, colsep).split(colsep)
        return {
            "error": error,
            "result": [row.split(linsep) for row in rows if row],
        }

    def run(self, **kwargs):
        """Execute ``-query`` once, or start an interactive SQL console.

        :return: the parsed result of ``-query``; the credential-check result
            when it returned no rows; otherwise an exit notice once the user
            leaves the console.
        """
        exited = {"result": "sql_console exited.", "error": False}

        dbms = self.args.get("dbms")
        database = self.args.get("database")

        # The vector name is the dbms, suffixed with "_database" when a
        # database is selected, or "_fallback" when no credentials are given.
        if database:
            vector = dbms + "_database"
        elif self.args.get("user") and self.args.get("passwd") is not None:
            vector = dbms
        else:
            vector = dbms + "_fallback"

        # If the query is set, just execute it
        if self.args.get("query"):
            return self._query(vector, self.args)

        # Else, start the console. Check credentials by fetching the user.
        self.args["query"] = "SELECT USER;" if dbms == "pgsql" else "SELECT USER();"

        result = self._query(vector, self.args)
        rows = result["result"]
        if not rows:
            return result

        # mysql vectors print a header row first and the others do not; the
        # user name is in the last row either way.
        user = rows[-1][0]

        # Console loop
        while True:
            try:
                query = input(f"{user}:{database or ''} SQL> ").strip()
            except EOFError:
                # Ctrl-D behaves like "exit"
                return exited

            if not query:
                continue
            if query in ("quit", r"\q", "exit"):
                return exited

            # Every query opens a new connection, so "USE db" is emulated by
            # passing the database to the vector on later queries.
            match = re.fullmatch(r"use\s+(\w+);?", query, re.IGNORECASE)
            if match:
                database = match.group(1)
                self.args["database"] = database
                vector = dbms + "_database"
                print(f"database changed to {database}.")
                continue

            self.args["query"] = query
            self.print_result(self._query(vector, self.args))

    def print_result(self, result):
        """Log the error text, the rows (first row as table header), or hints.

        When there are neither rows nor an error, warn that no data came back
        and, if the last query lacks a trailing ``;``, point that out.
        """
        if result["error"]:
            log.info(result["error"])

        if result["result"]:
            if isinstance(result["result"], str):
                log.info(result["result"])
            else:
                # NOTE(review): pgsql/fallback vectors emit no header row, so
                # their first data row is rendered as the header here.
                Module.print_result(self, result["result"], header=True)

        elif not result["error"]:
            log.warn("%s %s" % (messages.module_sql_console.no_data, messages.module_sql_console.check_credentials))

            command_last_chars = utils.prettify.shorten(self.args["query"].rstrip(), keep_trailer=10)

            if command_last_chars and command_last_chars[-1] != ";":
                log.warn(messages.module_sql_console.missing_sql_trailer_s % command_last_chars)
