import itertools
from itertools import combinations
import re
import ast 
import math
from copy import deepcopy

# General notes:
# Elements of ground set are labeled 0, ..., n-1
# All vectors and covectors are stored as pairs of tuples (positive part, negative part) so they can be hashed and de-duplicated
# All sign lists are stored as arrays of 0s and 1s for easy permutation/ flipping

############# Converting between oriented matroid representation (bases, vectors, circuits, covectors) #####################

# Generates a list of size rk subsets of [0,..,n-1] in reverse lexicographic order
def generate_subsets(rk, n):
    """Return all size-``rk`` subsets of ``{0, ..., n-1}`` as sorted tuples.

    The subsets come back in reverse lexicographic order: they are compared
    by their largest element first, then the second largest, and so on.
    For example ``generate_subsets(2, 4)`` is
    ``[(0, 1), (0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]``.
    """
    # Reading each tuple backwards and sorting ascending gives exactly the
    # reverse lexicographic order.
    return sorted(combinations(range(n), rk), key=lambda subset: subset[::-1])

# Puts a list of sign vectors into a dictionary
# keys are bases (ie size k subsets of [n]) and values are the sign of that basis
def basis_signs_to_dict(rk, n, basis_signs):
    """Attach a sign to every size-``rk`` subset of ``{0, ..., n-1}``.

    ``basis_signs`` is read one entry at a time: ``'+'`` becomes 1, ``'-'``
    becomes -1 and anything else becomes 0.  The i-th entry is paired with
    the i-th subset returned by ``generate_subsets(rk, n)``; if the two
    lengths differ, the surplus on the longer side is ignored.

    Returns a dict mapping basis tuples to 1, -1 or 0.
    """
    signs_by_basis = {}
    for basis, symbol in zip(generate_subsets(rk, n), basis_signs):
        if symbol == '+':
            signs_by_basis[basis] = 1
        elif symbol == "-":
            signs_by_basis[basis] = -1
        else:
            signs_by_basis[basis] = 0
    return signs_by_basis

# Takes in oriented bases and produces oriented circuits
# by sgn(c_i) = (-1)^i sgn(C\c_i) for each c_i in the circuit
# Each circuit is a list with positive and negative parts
# e.g. [0,1,2,3] might be stored as [[0,1],[2,3]]
def orient_circuit(circuit, basis_dict):
    """Split ``circuit`` into positive and negative parts using basis signs.

    The element at (0-based) position ``i`` receives the sign
    ``(-1)**(i+1) * basis_dict[circuit without that element]``.  Elements
    whose sign is 0 appear in neither part.

    Returns ``(positive_elements, negative_elements)`` as a pair of tuples.
    Raises ``KeyError`` if a required basis is missing from ``basis_dict``.
    """
    positive = []
    negative = []
    for position, element in enumerate(circuit):
        remaining = tuple(sorted(circuit[:position] + circuit[position + 1:]))
        # Alternating factor (-1)**(position + 1): -1, +1, -1, ...
        alternating = -1 if position % 2 == 0 else 1
        signed = alternating * basis_dict[remaining]
        if signed == 1:
            positive.append(element)
        elif signed == -1:
            negative.append(element)
    return (tuple(positive), tuple(negative))

def is_orthogonal(x,y):
    """Tell whether the signed sets ``x`` and ``y`` are orthogonal.

    Each argument is a ``(positive, negative)`` pair.  The result is True
    when the two either agree in sign somewhere *and* disagree in sign
    somewhere, or when they do neither (no common signed element).
    """
    x_pos, x_neg = set(x[0]), set(x[1])
    y_pos, y_neg = set(y[0]), set(y[1])
    has_opposite_signs = bool((x_pos & y_neg) or (x_neg & y_pos))
    has_equal_signs = bool((x_pos & y_pos) or (x_neg & y_neg))
    return has_opposite_signs == has_equal_signs

# Input k, n, an input file with a list of strings of basis orientations
# Output a file with a list of circuits in each matroid 
# USAGE: bases_to_circuits_files(3,5,"matroids35string.txt", "matroids35circuits.txt")
def bases_to_circuits_files(k,n, input_file, output_file):
    """Compute the signed circuits of every matroid listed in a file.

    ``input_file`` must contain a Python literal (parsed with
    ``ast.literal_eval``) that iterates over basis-sign strings.  For each
    string the set returned by ``bases_to_circuits(k, n, string)`` is
    written to ``output_file`` as ``str(set) + ",\\n"``.

    ``output_file`` is created or overwritten.
    """
    # Open read-only: the input is never modified, so write permission on
    # it must not be required.
    with open(input_file, 'r') as file:
        basis_strings = ast.literal_eval(file.read())

    with open(output_file, 'w') as outfile:
        for basis_string in basis_strings:
            # Same computation as the in-memory entry point, so the two
            # cannot drift apart.
            oriented_circuits = bases_to_circuits(k, n, basis_string)
            outfile.write(str(oriented_circuits)+",\n")

# Input: rank rk of matroid, number n of elements in ground set,
# and a single string of basis orientations (one character per basis)
# Output: the set of signed circuits (positive part, negative part) of that matroid
def bases_to_circuits(rk, n, basis_string):
    """Return the signed circuits obtained from a basis-sign string.

    Every size-``rk + 1`` subset of ``{0, ..., n-1}`` is oriented with
    ``orient_circuit``; the result contains each such signed set together
    with its negative (positive and negative parts swapped).
    """
    basis_dict = basis_signs_to_dict(rk, n, basis_string)
    one_orientation = {
        orient_circuit(candidate, basis_dict)
        for candidate in generate_subsets(rk + 1, n)
    }
    opposite_orientation = {(neg, pos) for (pos, neg) in one_orientation}
    return one_orientation | opposite_orientation

# Composes a vector circ_1 \comp ...\comp circ_{k-1} and a circuit circ_k
# Input eg: ((2,4),(5,3)),((1,2,5),())
# Output eg: 124 | 35
def compose(vector, circuit):
    """Return the composition of ``vector`` with ``circuit``.

    Both arguments are ``(positive, negative)`` pairs.  The result keeps
    every sign of ``vector`` and takes the sign of ``circuit`` only on
    elements that ``vector`` does not sign the opposite way:

        pos = vector+ | (circuit+ - vector-)
        neg = vector- | (circuit- - vector+)

    Each part is returned as a *sorted* tuple.  Set iteration order is not
    guaranteed, and callers keep signed sets in Python sets to remove
    duplicates, so every signed set needs exactly one tuple spelling.
    """
    vector_pos, vector_neg = set(vector[0]), set(vector[1])
    pos = vector_pos | (set(circuit[0]) - vector_neg)
    neg = vector_neg | (set(circuit[1]) - vector_pos)
    return (tuple(sorted(pos)), tuple(sorted(neg)))

# Compute the vectors from the circuits
# Input eg: 5, {((1, 2, 5), ()), ((1, 3, 5), (4,)), ((2, 4), (1, 3)), ((2, 4, 5), (3,))}
# Output eg: {((1, 3, 5), (4,)), ((1, 2, 4, 5), (3,)), ((1, 2, 3, 5), (4,)), ((2, 4, 5), (3,)), 
# ((1, 2, 5), ()), ((2, 4, 5), (1, 3)), ((2, 4), (1, 3))}
def circuits_to_vectors(n, circuits):
    """Close a collection of signed circuits under composition.

    Starting from ``circuits``, every known signed set is repeatedly
    composed (via ``compose``) with every circuit until no new signed set
    appears.

    Args:
        n: size of the ground set (accepted for interface compatibility;
            not used by the computation).
        circuits: iterable of ``(positive, negative)`` tuple pairs.

    Returns:
        A new set containing the circuits and all compositions reached.
        An empty input yields an empty set.
    """
    circuits = set(circuits)
    vectors = set(circuits)
    frontier = set(circuits)

    # Only signed sets found in the previous round can produce something
    # new, and anything already recorded is never composed again.
    while frontier:
        candidates = {compose(vect, circ) for vect in frontier for circ in circuits}
        frontier = candidates - vectors
        vectors |= frontier
    return vectors

# Generates all 2^n subdivisions of [0, n-1] into positive and negative parts
# eg: for division in all_orientations(3): print(division)
def all_orientations(n):
    """Yield all 2**n splits of ``{0, ..., n-1}`` into two parts.

    Each item is ``(subset, complement)``; both are tuples in increasing
    order.  Items are produced by increasing size of ``subset`` and, for
    equal size, in the order of ``itertools.combinations``.

    The complement is built by filtering the ground set in order rather
    than by converting a set to a tuple, so its ordering never depends on
    set iteration order.  Callers compare these pairs with their swapped
    counterparts and therefore need one spelling per signed set.
    """
    elements = range(n)
    for size in range(n + 1):
        for subset in combinations(elements, size):
            chosen = set(subset)
            complement = tuple(e for e in elements if e not in chosen)
            yield (subset, complement)

def is_covect(candidate_covect, vectors):
    """Return True when ``candidate_covect`` is orthogonal to every vector.

    Orthogonality is decided by ``is_orthogonal``; an empty ``vectors``
    collection gives True.
    """
    for vect in vectors:
        if not is_orthogonal(vect, candidate_covect):
            return False
    return True

# Takes in vectors and produces max covectors
# by checking for each possible max signed set whether it's orthogonal to all vectors
def vectors_to_max_covectors(n,vectors):
    """Return the full-support signed sets orthogonal to all ``vectors``.

    Every split produced by ``all_orientations(n)`` is tested with
    ``is_covect``; the ones that pass are collected in a set.
    """
    max_covectors = set()
    for candidate in all_orientations(n):
        if is_covect(candidate, vectors):
            max_covectors.add(candidate)
    return max_covectors

# Change tuples to array of zeroes and ones
def max_covectors_to_sign_vectors(n, covectors):
    """Encode each covector as a list of ``n`` zeros and ones.

    Position ``i`` is 0 when ``i`` lies in the covector's first (positive)
    part and 1 otherwise.  The output list follows the iteration order of
    ``covectors``.
    """
    return [
        [int(i not in covect[0]) for i in range(n)]
        for covect in covectors
    ]

# Putting everything together
def basis_to_sign_vectors(rk, n, sign_string):
    """Run the whole pipeline from a basis-sign string to 0/1 sign vectors.

    Steps: bases -> circuits -> vectors -> maximal covectors -> one
    representative per pair of opposite covectors -> 0/1 lists.
    """
    circuits = bases_to_circuits(rk, n, sign_string)
    max_covectors = vectors_to_max_covectors(n, circuits_to_vectors(n, circuits))
    return max_covectors_to_sign_vectors(n, remove_duplicate_regions(max_covectors))

# If a list contains (a,b) and (b,a)
# Return a list which keep only one representative (a,b) of each pair
def remove_duplicate_regions(input_list):
    """Keep one representative of each pair ``(a, b)`` / ``(b, a)``.

    Pairs are visited in iteration order; a pair is dropped when its
    swapped version has already been kept.  Exact repeats of a kept pair
    are not filtered out.  Returns a list of ``(a, b)`` tuples.
    """
    representatives = []
    # Swapped versions of everything kept so far.
    mirrors = set()
    for pos, neg in input_list:
        if (pos, neg) in mirrors:
            continue
        mirrors.add((neg, pos))
        representatives.append((pos, neg))
    return representatives

######################## Grasstope specific (manipulating regions etc) ###############################
    
############## Re-ordering and re-orienting ###############

# Permute a vector, where perm is given by a list of numbers from 0 to n-1
# eg permute_vector([3,4,5],(2,0,1))
# gives [5,3,4]
def permute_vector(sign_vector, perm):
    """Return the list whose j-th entry is ``sign_vector[perm[j]]``.

    Example: ``permute_vector([3, 4, 5], (2, 0, 1))`` gives ``[5, 3, 4]``.
    """
    permuted = []
    for source_index in perm:
        permuted.append(sign_vector[source_index])
    return permuted

# Flip_word tells you which hyperplanes to flip
# eg flip_vector([0,1,1,0],[2,3])
# gives [0,1,0,1]
# There are 2^n flip_words
def flip_vector(vector, flip_word):
    """Return a copy of ``vector`` with the entries at ``flip_word`` flipped.

    A flipped entry becomes ``(value - 1) % 2``, so 0 and 1 swap.  The new
    value is always computed from the input, so an index listed twice is
    still flipped only once.

    Example: ``flip_vector([0, 1, 1, 0], [2, 3])`` gives ``[0, 1, 0, 1]``.

    Args:
        vector: sequence of 0/1 entries; it is not modified.
        flip_word: iterable of indices to flip.

    Returns:
        A new list.
    """
    # A shallow copy is enough for a flat sequence of numbers and is much
    # cheaper than a deep copy.  It also yields a mutable list when the
    # input is a tuple.
    flipped_vector = list(vector)
    for index in flip_word:
        flipped_vector[index] = (vector[index] - 1) % 2
    return flipped_vector

def permute_vectors(sign_vector_list, perm):
    """Apply ``permute_vector(..., perm)`` to every vector; return a new list."""
    return [permute_vector(each, perm) for each in sign_vector_list]

def flip_vectors(sign_vector_list, flip_word):
    """Apply ``flip_vector(..., flip_word)`` to every vector; return a new list."""
    return [flip_vector(each, flip_word) for each in sign_vector_list]

####### Counting regions in Grasstope ############

# Input eg: print(sign_var([0,1,0,0,1]))
# Output eg: 3
def sign_var(sign_vector):
    """Count positions where consecutive entries of ``sign_vector`` differ.

    Example: ``sign_var([0, 1, 0, 0, 1])`` is 3.  Vectors with fewer than
    two entries give 0.
    """
    neighbours = zip(sign_vector, sign_vector[1:])
    return sum(1 for left, right in neighbours if left != right)

# Returns the list of sign vectors (covectors) with sign variation >= k
def grasstope(k, sign_vector_list):
    """Return, in order, the vectors whose ``sign_var`` is at least ``k``."""
    selected = []
    for vect in sign_vector_list:
        if sign_var(vect) >= k:
            selected.append(vect)
    return selected

def minmax_regions(k, n, sign_vector_list):
    """Search all flips and permutations for extremal Grasstope sizes.

    For every permutation ``perm`` of ``range(n)`` and every flip word
    (subset of ``range(n)``), the sign vectors are first flipped with
    ``flip_vectors`` and then permuted with ``permute_vectors``; the result
    is filtered with ``grasstope(k, ...)`` and its length compared with the
    best values so far.  Ties keep the first transformation found.

    The search visits ``n! * 2**n`` transformations.  A progress line is
    printed every 100 permutations, and each new maximum is printed
    together with its code word.

    Returns:
        ``(min_regions_gtope, max_regions_gtope, min_regions_code,
        max_regions_code)``.  Each code is a ``(perm, flip_word)`` pair of
        tuples; the untransformed input is reported as
        ``(tuple(range(n)), ())``.
    """
    # Initializing region counts with the untransformed arrangement
    gtope = grasstope(k, sign_vector_list)
    min_regions_gtope = gtope
    max_regions_gtope = gtope
    # A tuple (not a range object) so every returned code has the same
    # shape as the ones produced inside the loop.
    identity_code = (tuple(range(n)), ())
    min_regions_code = identity_code
    max_regions_code = identity_code

    # Flipping does not depend on the permutation, so do it once per flip
    # word instead of once per (permutation, flip word) pair.
    flip_words = itertools.chain.from_iterable(
        itertools.combinations(range(n), r) for r in range(n + 1)
    )
    flipped_lists = [
        (flip_word, flip_vectors(sign_vector_list, flip_word))
        for flip_word in flip_words
    ]

    # We flip and then permute.  Permutations are consumed lazily; there
    # are n! of them and none needs to be kept.
    for ind, perm in enumerate(itertools.permutations(range(n)), start=1):
        for flip_word, flipped in flipped_lists:
            new_gtope = grasstope(k, permute_vectors(flipped, perm))
            if len(new_gtope) < len(min_regions_gtope):
                min_regions_gtope = new_gtope
                min_regions_code = (perm, flip_word)
            elif len(new_gtope) > len(max_regions_gtope):
                max_regions_gtope = new_gtope
                max_regions_code = (perm, flip_word)
                print("new maximum: " + str(len(new_gtope)))
                print("code word: " + str(max_regions_code))
        if (ind % 100) == 0 :
            print("iteration %d complete" %ind)
    return (min_regions_gtope, max_regions_gtope, min_regions_code, max_regions_code)

# Total regions for n hyperplanes in PR^{k}
def total_regions(k, n):
    """Return ``b + (r - b) / 2`` as a float.

    Here ``r = C(n,0) + ... + C(n,k)`` and ``b = C(n-1, k)``.  Per the
    comment above this function, this is the total number of regions for
    n hyperplanes in PR^k (presumably in general position, with ``r`` and
    ``b`` the affine and bounded region counts - confirm).
    """
    r = 0
    for i in range(k + 1):
        r += math.comb(n, i)
    b = math.comb(n - 1, k)
    remainder = r - b
    return b + 0.5 * remainder

# Number of vectors with sign variance less than k
def beta(k, n):
    """Return ``C(n-1, 0) + C(n-1, 1) + ... + C(n-1, k-1)`` (0 when k <= 0)."""
    total = 0
    for i in range(k):
        total += math.comb(n - 1, i)
    return total

# For amplituhedra, can go straight to max covectors = signed sets with variation <= k
def gen_amp_vects(k, n):
    """Return the 0/1 vectors of length ``n`` with sign variation <= ``k``.

    Candidates come from ``generate_arrays(n)`` and are kept, in order,
    when ``sign_var`` is at most ``k``.
    """
    # sign_var is defined in this module; it must be called directly, not
    # through a module name that is never imported here.
    return [vect for vect in generate_arrays(n) if sign_var(vect) <= k]

# Generate all 0/1 arrays of size n whose first entry is 0
def generate_arrays(n):
    """Yield every 0/1 list of length ``n`` whose first entry is 0.

    The ``2**(n-1)`` lists are produced in increasing binary order of the
    remaining ``n - 1`` entries, e.g. for ``n = 3``:
    ``[0,0,0], [0,0,1], [0,1,0], [0,1,1]``.  For ``n = 1`` the single list
    ``[0]`` is produced.

    Raises:
        ValueError: when iterated with ``n < 1``.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    # product() enumerates the tails in the same order as counting in
    # binary, and handles an empty tail (n == 1) correctly.
    for tail in itertools.product((0, 1), repeat=n - 1):
        yield [0, *tail]

######## Outputting and formatting Grasstopes data #########

def print_full_grasstope_data(rk, n, sign_string):
    """Print circuits, regions and Grasstope statistics for one matroid.

    The matroid is given by its rank ``rk``, ground-set size ``n`` and
    basis-sign string.  The Grasstope and the min/max search both use
    ``rk - 1`` as the sign-variation threshold.
    """
    print("Rank = %s, Elts = %s, Basis signs = %s :" %(rk, n, sign_string))

    circuits = bases_to_circuits(rk, n, sign_string)
    print("Circuits: " + str(circuits))

    max_covectors = vectors_to_max_covectors(n, circuits_to_vectors(n, circuits))
    sign_vectors = max_covectors_to_sign_vectors(n, remove_duplicate_regions(max_covectors))
    print("Regions:" + str(sign_vectors))
    print("Number of regions: " +str(len(sign_vectors)))

    threshold = rk - 1
    gtope = grasstope(threshold, sign_vectors)
    print("Grasstope: " + str(gtope))
    print("Number regions in grasstope: " + str(len(gtope)))

    min_gtope, max_gtope, _min_code, max_code = minmax_regions(threshold, n, sign_vectors)
    print("Min and max regions: " + str((len(min_gtope), len(max_gtope))))
    # The extremal arrangements are reported by their (permutation, flips)
    # code rather than by listing the covectors themselves.
    print("Grasstope acheiving max: " + str(max_code))

# USAGE: minmax_regions_files(4, 7, "m47generic.txt", "m47extremal.txt")
def minmax_regions_files(rk, n, input_file, output_file):
    """Write the extremal Grasstope sizes of every matroid listed in a file.

    ``input_file`` must contain a Python literal (parsed with
    ``ast.literal_eval``) that iterates over basis-sign strings.  For each
    string one line ``"<string>: (<min>, <max>)"`` is written to
    ``output_file``, where the counts come from
    ``minmax_regions(rk - 1, n, ...)``.

    ``output_file`` is created or overwritten.
    """
    # Open read-only: the input is never modified, so write permission on
    # it must not be required.
    with open(input_file, 'r') as file:
        basis_strings = ast.literal_eval(file.read())

    with open(output_file, 'w') as outfile:
        for sign_string in basis_strings:
            my_sign_vectors = basis_to_sign_vectors(rk, n, sign_string)
            extremal = minmax_regions(rk-1, n, my_sign_vectors)
            outfile.write("%s: (%d, %d)\n" %(sign_string, len(extremal[0]), len(extremal[1])))

# Converts a vector (in the linear algebra sense) into its signs (+ and -)
def vector_to_signs(vect):
    """Return the sign pattern of a numeric vector as a string.

    Positive entries become ``'+'``, negative entries ``'-'`` and zero
    entries ``'0'``.  An empty vector gives ``''``.
    """
    # Every piece must be a string for str.join; zeros are written as the
    # character '0'.
    return ''.join('+' if x > 0 else '-' if x < 0 else '0' for x in vect)

####################### Tests ################################

# Gives counts for the class of the amplituhedron
def amplituhedron(k, n):
    """Print region counts for the sign vectors from ``gen_amp_vects(k, n)``.

    Reports the number of vectors, the value ``2**(n-1) - beta(k, n)``, the
    size of ``grasstope(k, ...)`` and the result of the min/max search.
    """
    amp_vectors = gen_amp_vects(k, n)
    print("number of regions: " + str(len(amp_vectors)))
    print("number of high variance vectors: " + str(2**(n-1) - beta(k, n)))
    print("number of regions in amplituhedron: "+ str(len(grasstope(k, amp_vectors))))

    min_gtope, max_gtope, _min_code, max_code = minmax_regions(k, n, amp_vectors)
    print("Min and max regions: " + str((len(min_gtope), len(max_gtope))))
    print("code: " + str(max_code))

# Gives counts for another matroid whose extremal values are (20, 38)
def non_amplithedron():
    """Print the region upper bound and min/max counts for a fixed matroid.

    The matroid is hard-coded as a 35-character basis-sign string on 7
    elements.
    """
    # NOTE(review): elsewhere in this file minmax_regions receives rk - 1
    # when the sign vectors come from a rank-rk matroid; here the rank and
    # the threshold are both 3.  The string length 35 fits C(7,3) as well
    # as C(7,4), so confirm which rank is intended.
    bstring = "++++++++++++++++++++++++-++++------"
    sign_vects = basis_to_sign_vectors(3, 7, bstring)
    min_gtope, max_gtope, _min_code, _max_code = minmax_regions(3, 7, sign_vects)
    print("Upper bound: %d" %total_regions(3, 7))
    print("Min and max regions: " + str((len(min_gtope), len(max_gtope))))

# Example with files
def extremal_counts_files():
    """Clean up ``ormat_data/m37uniform.txt`` in place and compute its counts.

    The file is stripped of its ``I...=`` labels, rewritten as a list of
    line strings, and then passed to ``minmax_regions_files`` with rank 3
    and 7 elements; results go to ``ormat_data/m37counts.txt``.
    """
    matroid_file = "ormat_data/m37uniform.txt"
    delete_text_between_I_and_equal(matroid_file)
    quote_lines(matroid_file, matroid_file)
    minmax_regions_files(3, 7, matroid_file, "ormat_data/m37counts.txt")

# Gives counts for the FMR matroid
def FMR():
    """Print the min/max region counts for the data in ``nonorientable.txt``.

    Each whitespace-separated word of the file is parsed with
    ``split_number_string`` into a signed cocircuit; the cocircuits are
    closed under composition, the signed sets with full support (8
    elements) are kept, and the min/max search is run with k = 3.
    """
    n = 8
    k = 3

    # NOTE(review): sign vectors are indexed by range(n) = 0..7; confirm
    # that the digits in the file use the same labelling.
    cocircuits = [split_number_string(word) for word in split_file("nonorientable.txt")]

    covectors = circuits_to_vectors(n, set(cocircuits))
    max_covectors = [
        covect for covect in covectors
        if len(covect[0]) + len(covect[1]) == n
    ]
    regions = remove_duplicate_regions(max_covectors)
    my_sign_vectors = max_covectors_to_sign_vectors(n, regions)

    minmax = minmax_regions(k, n, my_sign_vectors)
    print("(%d, %d)\n" %(len(minmax[0]), len(minmax[1])))

############## Helper functions for formatting the oriented matroids data from the database #####################

# Removes the DT....= part
# And puts the sign vectors (as strings) into a list
def _delete_label_through_equal(filename, first_letter):
    """Rewrite ``filename`` in place, removing every ``<letter>...=`` label.

    A label starts at ``first_letter``, runs up to and including the next
    ``=`` and also swallows the whitespace that follows it.
    """
    pattern = re.escape(first_letter) + r'[^=]*=\s*'
    with open(filename, 'r+') as file:
        content = file.read()
        modified_content = re.sub(pattern, '', content)
        # Overwrite from the start and cut off whatever is left of the
        # longer original text.
        file.seek(0)
        file.write(modified_content)
        file.truncate()


def delete_text_between_D_and_equal(filename):
    """Remove, in place, every span from a ``D`` through the next ``=``."""
    _delete_label_through_equal(filename, 'D')


def delete_text_between_I_and_equal(filename):
    """Remove, in place, every span from an ``I`` through the next ``=``."""
    _delete_label_through_equal(filename, 'I')

def quote_lines(input_filename, output_filename):
    """Write the stripped lines of one file as a Python list literal.

    Every line of ``input_filename`` is stripped of surrounding whitespace
    and the resulting list of strings is written, via ``str``, to
    ``output_filename``.  The input is read completely before the output
    is opened, so both names may refer to the same file.
    """
    with open(input_filename, 'r') as infile:
        stripped_lines = [raw_line.strip() for raw_line in infile]

    with open(output_filename, 'w') as outfile:
        outfile.write(repr(stripped_lines))

# Handles covectors in the file nonorientable.txt
def split_file(file_path):
    """Return the whitespace-separated words of the file as a list."""
    with open(file_path, 'r') as handle:
        return handle.read().split()

def split_number_string(s):
    """Split a digit string at its ``.`` into two tuples of digits.

    ``"12.34"`` gives ``((1, 2), (3, 4))``; without a dot the second tuple
    is empty, e.g. ``"567"`` gives ``((5, 6, 7), ())``.

    Raises ``ValueError`` for more than one dot or a non-digit character.
    """
    if '.' not in s:
        before_dot, after_dot = s, ''
    else:
        before_dot, after_dot = s.split('.')

    # Digits after the dot are converted first, then those before it.
    second = tuple(int(ch) for ch in after_dot)
    first = tuple(int(ch) for ch in before_dot)
    return (first, second)

# Usage:
# delete_text_between_D_and_equal("matroids47.txt") 
# quote_lines("matroids47.txt", "matroids47string.txt")





