#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

#include "matrice.h"
#include "separation.h"

// ***********************************************
// ***********************************************
// ***********************************************

void GeneralFuzzyInBetweenness(std::uint_fast64_t pi, std::uint_fast64_t qi, std::uint_fast64_t ri,
                                 std::shared_ptr<Matrice<double>> dominance,
                                 NormConorm& times,
                                 NormConorm& plus,
                                 double& finb_prq, double& finb_qrp, double& finbqrp) {

    double sdom_pr = (pi != ri) * dominance->at(pi, ri);
    double sdom_rq = (qi != ri) * dominance->at(ri, qi);
    finb_prq = times(times(dominance->at(pi, qi), sdom_pr), sdom_rq);
        
    double sdom_qr = (qi != ri) * dominance->at(qi, ri);
    double sdom_rp = (pi != ri) * dominance->at(ri, pi);
    finb_qrp = times(times(dominance->at(qi, pi), sdom_qr), sdom_rp);
        
    finbqrp = plus(finb_prq, finb_qrp);
};


// ***********************************************
// ***********************************************
// ***********************************************

/**
 * Sums the fuzzy in-betweenness of every row index r of `dominance` for the pair (pi, qi).
 *
 * The three outputs are reset to 0 and then accumulated with ordinary `+` (not `plus`):
 *   cfinb_pq += finb_prq   (orientation p -> r -> q)
 *   cfinb_qp += finb_qrp   (orientation q -> r -> p)
 *   cfinbpq  += finbqrp    (the `plus` of both orientations)
 * where the per-r terms come from GeneralFuzzyInBetweenness(pi, qi, r, ...).
 */
void CumulativeFuzzyInBetweenness(std::uint_fast64_t pi, std::uint_fast64_t qi,
                                 std::shared_ptr<Matrice<double>> dominance,
                                 NormConorm& times,
                                 NormConorm& plus,
                                 double& cfinbpq, double& cfinb_pq, double& cfinb_qp) {
    cfinbpq = 0.0;
    cfinb_pq = 0.0;
    cfinb_qp = 0.0;

    const auto n_rows = dominance->Rows();
    for (std::uint_fast64_t r = 0; r < n_rows; ++r) {
        double forward = 0.0;   // r between p and q
        double backward = 0.0;  // r between q and p
        double either = 0.0;    // plus() of the two orientations
        GeneralFuzzyInBetweenness(pi, qi, r, dominance, times, plus, forward, backward, either);

        cfinb_pq += forward;
        cfinb_qp += backward;
        cfinbpq += either;
    }
}

// ***********************************************
// ***********************************************
// ***********************************************

/**
 * Separation matrices derived from a fuzzy dominance matrix D = *dominance.
 *
 * For every pair of indices pi < qi (over dominance->Rows()), with the cumulative
 * in-betweenness terms from CumulativeFuzzyInBetweenness:
 *   sep_all(p,q)   = sep_all(q,p) = D(p,q) + D(q,p) + cfinbpq
 *   sep_lower(p,q) = D(p,q) + cfinb_pq,   sep_lower(q,p) = D(q,p) + cfinb_qp
 *   sep_upper(p,q) = D(q,p) + cfinb_qp,   sep_upper(q,p) = D(p,q) + cfinb_pq
 * so off the diagonal sep_upper is the transpose of sep_lower. Diagonal cells are never
 * written and keep whatever Matrice's (rows, cols) constructor puts there.
 *
 * @param do_all, do_lower, do_upper  which matrices to build; unrequested ones are nullptr.
 * @return (sep_all, sep_lower, sep_upper), each sized Rows() x Cols() of `dominance`.
 */
std::tuple<std::shared_ptr<Matrice<double>>, std::shared_ptr<Matrice<double>>, std::shared_ptr<Matrice<double>>>
GeneralSeparation(
                  std::shared_ptr<Matrice<double>> dominance,
                  NormConorm& times,
                  NormConorm& plus,
                  bool do_all, bool do_lower, bool do_upper) {
    const auto n_rows = dominance->Rows();
    const auto n_cols = dominance->Cols();

    // Only allocate the results the caller asked for.
    auto allocate_if = [n_rows, n_cols](bool requested) -> std::shared_ptr<Matrice<double>> {
        if (!requested) {
            return nullptr;
        }
        return std::make_shared<Matrice<double>>(n_rows, n_cols);
    };
    auto sep_lower = allocate_if(do_lower);
    auto sep_upper = allocate_if(do_upper);
    auto sep_all = allocate_if(do_all);

    auto& dom = *dominance;
    for (std::uint_fast64_t pi = 0; pi < n_rows; ++pi) {
        for (std::uint_fast64_t qi = pi + 1; qi < n_rows; ++qi) {
            double between_any = 0.0;  // cfinbpq
            double between_pq = 0.0;   // cfinb_pq
            double between_qp = 0.0;   // cfinb_qp
            CumulativeFuzzyInBetweenness(pi, qi, dominance, times, plus,
                                         between_any, between_pq, between_qp);

            if (sep_all) {
                const double total = dom.at(pi, qi) + dom.at(qi, pi) + between_any;
                (*sep_all)(pi, qi) = total;
                (*sep_all)(qi, pi) = total;
            }
            if (sep_lower) {
                (*sep_lower)(pi, qi) = dom.at(pi, qi) + between_pq;
                (*sep_lower)(qi, pi) = dom.at(qi, pi) + between_qp;
            }
            if (sep_upper) {
                (*sep_upper)(pi, qi) = dom.at(qi, pi) + between_qp;
                (*sep_upper)(qi, pi) = dom.at(pi, qi) + between_pq;
            }
        }
    }
    return std::make_tuple(sep_all, sep_lower, sep_upper);
}

// ***********************************************
// ***********************************************
// ***********************************************

/**
 * Closed-form lexicographic separation between two profiles p and q whose variables all
 * have `numero_modalita` modalities.
 *
 * Notation (k = numero_variabili, m = numero_modalita):
 *   k1 = #{i : p_i > q_i},  k2 = #{i : p_i == q_i},  k3 = #{i : p_i < q_i}
 *   T_plus  = sum over p_i > q_i of (p_i - q_i)   (>= 0)
 *   T_minus = sum over p_i < q_i of (p_i - q_i)   (<= 0)
 *   F1 = sum_{j=0..k2} (k-j-1)!/(k2-j)! * m^(k-j-1)                  (only when k2 <= k-1)
 *   F2 = sum_{j=0..k2} (k-j-2)!/(k2-j)! * sum_{h=j+2..k} m^(k-h)     (only when k2 <= k-2)
 *
 * @return {sep_all, sep_lower, sep_upper}. Algebraically sep_all == sep_lower + sep_upper,
 *         and all three are 0 when p == q. On small cases checked by hand, sep_upper/sep_lower
 *         equal the average lexicographic rank distance over all k! variable priorities in
 *         which p ranks above/below q — treat that reading as an assumption.
 * @throws std::invalid_argument if p and q do not both have numero_variabili components.
 */
std::tuple<double, double, double> LexSeparation(std::uint_fast64_t numero_variabili, std::uint_fast64_t numero_modalita, std::vector<std::uint_fast64_t>& p, std::vector<std::uint_fast64_t>& q) {
    if (p.size() != q.size() || p.size() != numero_variabili) {
        throw std::invalid_argument("LexSeparation: p and q must both have numero_variabili components");
    }

    // The formulas only need the three counts and the two signed sums, so gather them in
    // one pass without materialising index sets (this runs numero_profili^2 times).
    std::uint_fast64_t n_greater = 0;  // k1
    std::uint_fast64_t n_equal = 0;    // k2
    std::uint_fast64_t n_less = 0;     // k3
    double T_plus = 0;
    double T_minus = 0;
    for (std::size_t i = 0; i < p.size(); ++i) {
        const double diff = static_cast<double>(p[i]) - static_cast<double>(q[i]);
        if (p[i] > q[i]) {
            ++n_greater;
            T_plus += diff;
        } else if (p[i] < q[i]) {
            ++n_less;
            T_minus += diff;
        } else {
            ++n_equal;
        }
    }

    const double k = static_cast<double>(numero_variabili);
    const double m = static_cast<double>(numero_modalita);
    const double k1 = static_cast<double>(n_greater);
    const double k2 = static_cast<double>(n_equal);
    const double k3 = static_cast<double>(n_less);

    double F1 = 0;  // F1(m, k, k2)
    double F2 = 0;  // F2(m, k, k2)
    if (k2 <= k - 1) {  // false only when p == q (k2 == k): then every separation is 0
        const bool with_F2 = (k2 <= k - 2);
        for (std::uint_fast64_t j = 0; j <= n_equal; ++j) {
            const double jd = static_cast<double>(j);
            const double num1 = std::tgamma(k - jd);      // (k - j - 1)!
            const double den = std::tgamma(k2 - jd + 1);  // (k2 - j)!
            F1 += (num1 / den) * std::pow(m, k - jd - 1);
            if (with_F2) {
                // (k - j - 2)!; the divisor is >= 1 because j <= k2 <= k - 2.
                const double num2 = num1 / (k - jd - 1);
                double inner_sum = 0;
                for (std::uint_fast64_t h = j + 2; h <= numero_variabili; ++h) {
                    inner_sum += std::pow(m, k - static_cast<double>(h));
                }
                F2 += (num2 / den) * inner_sum;
            }
        }
    }

    const double k2_fact = std::tgamma(k2 + 1);
    const double k_fact = std::tgamma(k + 1);
    const double sep_all = (k2_fact * (((T_plus - T_minus) * F1) + ((T_plus * (k1 - k3 - 1) + T_minus * (k1 - k3 + 1)) * F2))) / k_fact;
    const double sep_upper = (k2_fact * (T_plus * F1 + (T_plus * (k1 - 1) + T_minus * k1) * F2)) / k_fact;
    const double sep_lower = (-k2_fact * (T_minus * F1 + (T_minus * (k3 - 1) + T_plus * k3) * F2)) / k_fact;

    return {sep_all, sep_lower, sep_upper};
}

// ***********************************************
// ***********************************************
// ***********************************************

std::tuple<
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<std::vector<std::vector<std::uint_fast64_t>>>
>
LexSeparationEqDeg(std::uint_fast64_t numero_variabili, std::uint_fast64_t numero_modalita) {
    std::uint_fast64_t numero_profili = std::pow(numero_modalita, numero_variabili);
    auto sep_all = std::make_shared<Matrice<double>>(numero_profili, numero_profili, std::numeric_limits<double>::quiet_NaN());
    auto sep_lower = std::make_shared<Matrice<double>>(numero_profili, numero_profili, std::numeric_limits<double>::quiet_NaN());
    auto sep_upper = std::make_shared<Matrice<double>>(numero_profili, numero_profili, std::numeric_limits<double>::quiet_NaN());
    auto sep_vertical = std::make_shared<Matrice<double>>(numero_profili, numero_profili, std::numeric_limits<double>::quiet_NaN());
    auto sep_horizontal = std::make_shared<Matrice<double>>(numero_profili, numero_profili, std::numeric_limits<double>::quiet_NaN());

    auto profili = std::make_shared<std::vector<std::vector<std::uint_fast64_t>>>(numero_profili, std::vector<std::uint_fast64_t>(numero_variabili, 0));

    {
        std::uint_fast64_t modalita = 1;
        for (std::uint_fast64_t profilo_id = 1; profilo_id < numero_profili; ++profilo_id) {
            auto& profilo = profili->at(profilo_id);
            auto& profilo_prec = profili->at(profilo_id - 1);

            if (modalita == numero_modalita) {
                std::uint_fast64_t p = numero_variabili - 2;
                for (; ; --p) {
                    auto v = profilo_prec.at(p);
                    if (v != numero_modalita - 1) {
                        profilo.at(p) = profilo_prec.at(p) + 1;
                        for (std::uint_fast64_t k = 0; k < p; ++k) {
                            profilo.at(k) = profilo_prec.at(k);
                        }
                        break;
                    }
                    if (p == 0) {
                        break;
                    }
                }
                profilo.at(numero_variabili - 1) = 0;
                modalita = 1;
            } else {
                for (std::uint_fast64_t p = 0; p < numero_variabili - 1; ++p) {
                    profilo.at(p) = profilo_prec.at(p);
                }
                profilo.at(numero_variabili - 1) = modalita;
                ++modalita;
            }
        }
    }
    
    
    for (std::uint_fast64_t pi = 0; pi < numero_profili; ++pi) {
        for (std::uint_fast64_t qi = 0; qi < numero_profili; ++qi) {
            auto& p = profili->at(pi);
            auto& q = profili->at(qi);
            auto sep_pq = LexSeparation(numero_variabili, numero_modalita, p, q);
            auto sep_all_pq = std::get<0>(sep_pq);
            auto sep_lower_pq = std::get<1>(sep_pq);
            auto sep_upper_pq = std::get<2>(sep_pq);
            auto sep_vertical_pq = std::abs(sep_lower_pq - sep_upper_pq);
            auto sep_horizontal_pq = sep_all_pq - sep_vertical_pq;
            (*sep_all)(pi, qi) = sep_all_pq;
            (*sep_lower)(pi, qi) = sep_lower_pq;
            (*sep_upper)(pi, qi) = sep_upper_pq;
            (*sep_vertical)(pi, qi) = sep_vertical_pq;
            (*sep_horizontal)(pi, qi) = sep_horizontal_pq;
        }
    }
    
    return {sep_all, sep_lower, sep_upper, sep_vertical, sep_horizontal, profili};
}


// ***********************************************
// ***********************************************
// ***********************************************

std::tuple<
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<std::vector<std::vector<std::uint_fast64_t>>>
>
LexSeparationDeg(std::vector<std::uint_fast64_t>& numero_modalita) {
    std::uint_fast64_t numero_profili = std::accumulate(numero_modalita.begin(), numero_modalita.end(), 1.0, std::multiplies<double>());
    std::uint_fast64_t numero_variabili = numero_modalita.size();
    auto sep_all = std::make_shared<Matrice<double>>(numero_profili, numero_profili, 0);
    auto sep_lower = std::make_shared<Matrice<double>>(numero_profili, numero_profili, 0);
    auto sep_upper = std::make_shared<Matrice<double>>(numero_profili, numero_profili, 0);
    auto sep_vertical = std::make_shared<Matrice<double>>(numero_profili, numero_profili, 0);
    auto sep_horizontal = std::make_shared<Matrice<double>>(numero_profili, numero_profili, 0);

    auto calcolaLLE = [](std::vector<std::uint_fast64_t>& variable_priority, std::vector<std::uint_fast64_t>& numero_modalita, std::vector<std::vector<std::uint_fast64_t>>& lle) {
        std::uint_fast64_t numero_variabili = variable_priority.size();

        std::uint_fast64_t modalita = 1;
        auto last_v = variable_priority.at(numero_variabili - 1);
        
        lle.at(0) = std::vector<std::uint_fast64_t>(numero_variabili, 0);
        
        for (std::uint_fast64_t profilo_id = 1; profilo_id < lle.size(); ++profilo_id) {
            auto& profilo = lle.at(profilo_id);
            auto& profilo_prec = lle.at(profilo_id - 1);

            if (modalita == numero_modalita.at(last_v)) {
                std::uint_fast64_t p = numero_variabili - 2;
                for (; ; --p) {
                    auto var = variable_priority.at(p);
                    auto v = profilo_prec.at(var);
                    if (v != numero_modalita.at(var) - 1) {
                        profilo.at(var) = profilo_prec.at(var) + 1;
                        for (std::uint_fast64_t k = 0; k < p; ++k) {
                            auto var_k = variable_priority.at(k);
                            profilo.at(var_k) = profilo_prec.at(var_k);
                        }
                        for (std::uint_fast64_t k = p + 1; k < numero_variabili; ++k) {
                            auto var_k = variable_priority.at(k);
                            profilo.at(var_k) = 0;
                        }
                        break;
                    }
                    if (p == 0) {
                        break;
                    }
                }
                profilo.at(last_v) = 0;
                modalita = 1;
            } else {
                for (std::uint_fast64_t p = 0; p < numero_variabili - 1; ++p) {
                    auto var_p = variable_priority.at(p);
                    profilo.at(var_p) = profilo_prec.at(var_p);
                }
                profilo.at(last_v) = modalita;
                ++modalita;
            }
        }
    };
    
    auto variable_priority = std::vector<std::uint_fast64_t>(numero_variabili);
    std::iota(variable_priority.begin(), variable_priority.end(), 0);
    auto profili = std::make_shared<std::vector<std::vector<std::uint_fast64_t>>>(numero_profili, std::vector<std::uint_fast64_t>(numero_variabili, 0));
    calcolaLLE(variable_priority, numero_modalita, *profili);
    std::map<std::vector<std::uint_fast64_t>, std::uint_fast64_t> posizione_profili;
    for (std::uint_fast64_t p_id = 0; p_id < profili->size(); ++p_id) {
        posizione_profili[profili->at(p_id)] = p_id;
    }

    double numero_lex = std::tgamma(numero_variabili + 1);
    std::vector<std::vector<std::uint_fast64_t>> lle(numero_profili, std::vector<std::uint_fast64_t>(numero_variabili, 0));
    do {
        calcolaLLE(variable_priority, numero_modalita, lle);
        {
            // per ogni coppia di profili in lle calcolo la separazione
            for (std::uint_fast64_t p_id = 0; p_id < lle.size(); ++p_id) {
                for (std::uint_fast64_t q_id = p_id + 1; q_id < lle.size(); ++q_id) {
                    double distanza = std::abs(((double) q_id) - p_id);
                    auto&p = lle.at(p_id);
                    auto&q = lle.at(q_id);
                    std::uint_fast64_t sep_p_id = posizione_profili.at(p);
                    std::uint_fast64_t sep_q_id = posizione_profili.at(q);
                    
                    (*sep_all)(sep_p_id, sep_q_id) += (distanza / numero_lex);
                    (*sep_all)(sep_q_id, sep_p_id) = (*sep_all)(sep_p_id, sep_q_id);
                    
                    (*sep_lower)(sep_p_id, sep_q_id) += (distanza / numero_lex);
                    (*sep_upper)(sep_q_id, sep_p_id) += (distanza / numero_lex);


                }
            }
        }
    } while (std::next_permutation(variable_priority.begin(), variable_priority.end()));
    
    {
        // per ogni coppia di profili in lle calcolo la separazione
        for (std::uint_fast64_t p_id = 0; p_id < sep_all->Rows(); ++p_id) {
            for (std::uint_fast64_t q_id = 0; q_id < sep_all->Rows(); ++q_id) {
                auto sep_all_pq = (*sep_all)(p_id, q_id);
                auto sep_lower_pq = (*sep_lower)(p_id, q_id);
                auto sep_upper_pq = (*sep_upper)(p_id, q_id);
                auto sep_vertical_pq = std::abs(sep_lower_pq - sep_upper_pq);
                auto sep_horizontal_pq = sep_all_pq - sep_vertical_pq;
                (*sep_vertical)(p_id, q_id) = sep_vertical_pq;
                (*sep_horizontal)(p_id, q_id) = sep_horizontal_pq;
            }
        }
    }

    return {sep_all, sep_lower, sep_upper, sep_vertical, sep_horizontal, profili};
}

// ***********************************************
// ***********************************************
// ***********************************************

std::tuple<
    std::shared_ptr<Matrice<double>>,
    std::shared_ptr<std::vector<std::vector<std::uint_fast64_t>>>
>
LexMRP(std::vector<std::uint_fast64_t>& numero_modalita) {
    std::uint_fast64_t numero_profili = std::accumulate(numero_modalita.begin(), numero_modalita.end(), 1.0, std::multiplies<double>());
    std::uint_fast64_t numero_variabili = numero_modalita.size();
    
    auto mrp = std::make_shared<Matrice<double>>(numero_profili, numero_profili, 0.0);
    
    
    auto quanti = [](std::vector<std::uint_fast64_t>& p, std::vector<std::uint_fast64_t>& q) {
        std::uint_fast64_t p_minore_q = 0;
        std::uint_fast64_t p_uguale_q = 0;
        std::uint_fast64_t p_maggiore_q = 0;
        
        for (std::uint_fast64_t k = 0; k < p.size(); ++k) {
            if (p.at(k) < q.at(k)) {
                ++p_minore_q;
            } else if (p.at(k) == q.at(k)) {
                ++p_uguale_q;
            } else {
                ++p_maggiore_q;
            }
        }
        return std::make_tuple(p_minore_q, p_uguale_q, p_maggiore_q);
    };

    auto profili = std::make_shared<std::vector<std::vector<std::uint_fast64_t>>>(numero_profili, std::vector<std::uint_fast64_t>(numero_variabili, 0));
    std::map<std::vector<std::uint_fast64_t>, std::uint_fast64_t> p_to_profili;

    {
        p_to_profili[std::vector<std::uint_fast64_t>(numero_variabili, 0)] = 0;
        
        std::uint_fast64_t modalita = 1;
        for (std::uint_fast64_t profilo_id = 1; profilo_id < numero_profili; ++profilo_id) {
            auto& profilo = profili->at(profilo_id);
            auto& profilo_prec = profili->at(profilo_id - 1);

            if (modalita == numero_modalita.at(numero_variabili - 1)) {
                std::uint_fast64_t p = numero_variabili - 2;
                for (; ; --p) {
                    auto v = profilo_prec.at(p);
                    if (v != numero_modalita.at(p) - 1) {
                        profilo.at(p) = profilo_prec.at(p) + 1;
                        for (std::uint_fast64_t k = 0; k < p; ++k) {
                            profilo.at(k) = profilo_prec.at(k);
                        }
                        break;
                    }
                    if (p == 0) {
                        break;
                    }
                }
                profilo.at(numero_variabili - 1) = 0;
                modalita = 1;
            } else {
                for (std::uint_fast64_t p = 0; p < numero_variabili - 1; ++p) {
                    profilo.at(p) = profilo_prec.at(p);
                }
                profilo.at(numero_variabili - 1) = modalita;
                ++modalita;
            }
            p_to_profili[profilo] = profilo_id;
        }
    }
    double k = numero_variabili;
    for (std::uint_fast64_t p_id = 0; p_id < profili->size(); ++p_id) {
        for (std::uint_fast64_t q_id = p_id + 1; q_id < profili->size(); ++q_id) {
            auto&p = profili->at(p_id);
            auto&q = profili->at(q_id);
            auto r = quanti(p, q);

            double k1 = std::get<0>(r);
            double k2 = std::get<1>(r);
            
            for (std::uint_fast64_t s = 0; s <= k2; ++s) {
                double k_s_1_f = std::tgamma(k - s);
                double k2_s_f = std::tgamma(k2 - s + 1);

                (*mrp)(p_id, q_id) += (k_s_1_f / k2_s_f);
            }
            auto k2_f = std::tgamma(k2 + 1);
            auto k_f = std::tgamma(k + 1);
            (*mrp)(p_id, q_id) = (k1 * k2_f * (*mrp)(p_id, q_id)) / k_f;
            (*mrp)(q_id, p_id) = 1.0 - (*mrp)(p_id, q_id);
        }
    }
    //std::cout << "LLEMrp:\n" << mrp->to_string() << std::endl << std::endl;

    return {mrp, profili};
}

