import numpy as np
import pandas as pd
from mm_stats.definitions import DUMP_PATH
from mm_stats.auth import PostgresDb
from mm_stats.definitions import logger


def create_table(_filter: str):
    """Create tables for a Lorenz curve plot of user contributions.

    Builds three tables in the schema named by ``_filter`` —
    ``users_lorenz_curve_all``, ``_mapped`` and ``_validated`` — each
    holding (x, y) points of a Lorenz curve over per-user session counts.
    The curve data is staged through a TSV dump file and bulk-loaded with
    ``copy_from``.

    Args:
        _filter: Name of the Postgres schema to operate in.
    """
    db = PostgresDb()
    # NOTE(review): the schema name is interpolated directly into SQL.
    # Identifiers cannot be parameter-bound, so ``_filter`` must come from
    # a trusted source — confirm callers never pass user-supplied input.
    sql = "set schema '{_filter}';".format(_filter=_filter)

    # One (table suffix, filter clause) pair per table to build.
    variants = [
        ("_all", ""),
        ("_mapped", "where action = 'MAPPED'"),
        ("_validated", "where action = 'VALIDATED'"),
    ]
    # Staging file path is loop-invariant; compute it once.
    dump = DUMP_PATH + "/lorenz.csv"

    for suffix, clause in variants:
        sql_query = f"""
            SELECT count(*)
            FROM sessions
            {clause}
            GROUP BY userid
            ORDER BY count
            """
        where_sql = sql + sql_query

        # Each row is a 1-tuple (count,); flatten to a list of floats.
        data = list(db.retr_query(where_sql))
        data = [float(row[0]) for row in data]
        x_data, y_data = lorenz_curve(data)
        df = pd.DataFrame({"x": x_data, "y": y_data})
        df.to_csv(dump, sep="\t", index=False, header=False)

        table_name = "users_lorenz_curve{}".format(suffix)
        db.query(
            """
            drop table if exists {};
            create table {} (x float, y float);
            """.format(
                table_name, table_name
            )
        )

        # Bulk-load the staged TSV into the freshly created table.
        with open(dump, "r") as f:
            db.copy_from(f, table_name, ["x", "y"])

    logger.info(f"{_filter}: created table users_lorenz_curve")


def x_element(num, n_th):
    """Calculate which entries should be used for a dense curve.

    Produces ``n_th`` indices spread evenly over ``[0, num)`` by rounding
    multiples of the step ``num / n_th``.

    Args:
        num: Total number of entries available.
        n_th: Number of indices to produce (must be > 0, else
            ``ZeroDivisionError`` — same as the original behavior).

    Returns:
        List of ``n_th`` rounded indices starting at 0.
    """
    # Multiply instead of repeatedly adding the step so floating-point
    # rounding error does not accumulate for large n_th.
    step = num / n_th
    return [round(i * step) for i in range(n_th)]


def lorenz_curve(x: "np.ndarray | list") -> tuple:
    """Calculate a Lorenz curve and limit it to ~1000 points.

    Args:
        x: Sequence of non-negative contribution counts, assumed sorted
            ascending (callers order the query by count). Must be
            non-empty.

    Returns:
        Tuple ``(x_data, y_data)`` of equal-length numpy arrays:
        ``x_data`` is evenly spaced over [0, 1] and ``y_data`` is the
        cumulative share of contributions, always ending at 1.0.
    """
    x = np.asarray(x, dtype=float)
    y_data = x.cumsum() / x.sum()

    if len(y_data) > 1000:
        # Resample down to 999 interior points; the final point is
        # appended separately below so the curve always ends at 1.0.
        # NOTE(review): for sizes just above 1000 the last resampled
        # index may round to size-1, duplicating the end point — same as
        # the original behavior, harmless for plotting.
        index_list = x_element(y_data.size, 999)
    else:
        # Use every point except the last, which is appended below.
        # (Bug fix: the original used range(0, len(y_data)) here and then
        # appended the final value again, duplicating the last point and
        # flattening the curve's final segment.)
        # The max(..., 1) guard keeps a single-element input producing
        # two points [(0, 1), (1, 1)], matching the original.
        index_list = range(0, max(len(y_data) - 1, 1))

    down_sampled = [y_data[i] for i in index_list]
    down_sampled.append(y_data[-1])
    y_data = np.array(down_sampled)
    # Map the retained points onto an even grid over [0, 1].
    x_data = np.arange(y_data.size) / (y_data.size - 1)
    return x_data, y_data
