import os
from collections.abc import Sequence
from pathlib import Path

import pandas as pd
from signac import JSONDict

class Aggregate:
    """Handle to a single aggregate directory on disk.

    On construction, ``aggregate_jobs.txt`` inside the directory is read and
    each stripped line is stored in :attr:`jobs`.  The aggregate state point
    (``aggregate_statepoint.json``) and document (``aggregate_document.json``)
    are wrapped in signac ``JSONDict`` objects lazily, on first access to
    :attr:`_sp` / :attr:`_doc`.

    The instance is also a context manager: inside a ``with`` block the
    process working directory is the aggregate directory.
    """

    def __init__(self, path):
        """Open the aggregate whose top directory is *path* (str or PathLike)."""
        self.path = Path(path)
        # The directory name (minus its last suffix, if any) is the aggregate id.
        self.id = self.path.stem
        with open(self.path / "aggregate_jobs.txt", "r", encoding="utf-8") as f:
            # Presumably one job id per line; strip() removes the newline and
            # any surrounding whitespace.
            self.jobs = [line.strip() for line in f]
        self.fn_sp = "aggregate_statepoint.json"
        self.fn_doc = "aggregate_document.json"
        # Caches for the lazily created JSONDict wrappers.
        self._sp_ = None
        self._doc_ = None

    @property
    def _sp(self):
        """Aggregate state point as a ``JSONDict`` (created on first access)."""
        if self._sp_ is None:
            self._sp_ = JSONDict(filename=self.path / self.fn_sp, write_concern=True)
        return self._sp_

    @property
    def _doc(self):
        """Aggregate document as a ``JSONDict`` (created on first access)."""
        if self._doc_ is None:
            self._doc_ = JSONDict(filename=self.path / self.fn_doc, write_concern=True)
        return self._doc_

    def __enter__(self):
        """Change into the aggregate directory and return this aggregate."""
        self.origpath = os.getcwd()
        os.chdir(self.path)
        # Return self so ``with agg as a:`` binds the aggregate rather than None.
        return self

    def __exit__(self, excep_type, excep_value, excep_traceback):
        """Restore the working directory saved by ``__enter__``."""
        os.chdir(self.origpath)
        # Falsy return value: exceptions raised inside the with-body propagate.
        return None

# A class that creates a pandas dataframe based on aggregate data             
class AggregateDataset:
    """pandas DataFrame view over all aggregates below a root directory.

    Every immediate subdirectory of *root* is opened as an ``Aggregate``.
    :attr:`data` holds one row per aggregate, indexed by aggregate id, with
    state point keys prefixed ``"sp."`` and document keys prefixed ``"doc."``.
    List-like values in object columns are converted to (nested) tuples, so
    that, unless they contain mappings, they are hashable.
    """

    def __init__(self, root):
        # root is the directory containing the aggregate folders
        # (presumably named agg-IDHASH).
        self.root = Path(root)
        # iterdir() order is filesystem-dependent; sort for a reproducible row order.
        subdirs = sorted(d for d in self.root.iterdir() if d.is_dir())
        self.aggregates = [Aggregate(d) for d in subdirs]
        self._aggids = [a.id for a in self.aggregates]

        self.data = self._create_dataframe()
        self._make_object_columns_hashable(self.data)

    def aggbyid(self, aggid):
        """Return the aggregate whose id is *aggid*.

        Raises:
            ValueError: if no aggregate has that id.
        """
        return self.aggregates[self._aggids.index(aggid)]

    @staticmethod
    def _to_hashable(value):
        """Recursively turn non-string sequences into tuples; return others unchanged."""
        # str/bytes are Sequences too, but must stay intact (tuple("ab") == ("a", "b")).
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
            return tuple(AggregateDataset._to_hashable(v) for v in value)
        return value

    @classmethod
    def _make_object_columns_hashable(cls, df):
        """Convert list-like cells of every object-dtype column of *df* in place."""
        for col in df.columns:
            if df[col].dtype == object:
                # Series.map returns a new Series; it must be assigned back.
                # NaN (missing keys) and scalars pass through unchanged.
                df[col] = df[col].map(cls._to_hashable)

    def _create_dataframe(self, usecols=None):
        """Build a DataFrame with one row per aggregate.

        Args:
            usecols: ``None`` to keep all columns, a callable taking a prefixed
                column name and returning a truthy value to keep it, or an
                iterable of prefixed column names to keep.

        Returns:
            DataFrame indexed by aggregate id (in ``self.aggregates`` order),
            with columns ``"sp.<key>"`` / ``"doc.<key>"``. Aggregates with no
            selected keys still get an (all-NaN) row.
        """
        sp_prefix = "sp."
        doc_prefix = "doc."

        if usecols is None:
            def keep(col):
                return True
        elif callable(usecols):
            keep = usecols
        else:
            included_columns = set(usecols)

            def keep(col):
                return col in included_columns

        data = {}
        for agg in self.aggregates:
            row = {}
            for prefix, source in ((sp_prefix, agg._sp), (doc_prefix, agg._doc)):
                for key, value in source.items():
                    column = prefix + key
                    if keep(column):
                        row[column] = value
            data[agg.id] = row

        frame = pd.DataFrame.from_dict(data=data, orient="index")
        # from_dict(orient="index") drops ids whose row dict is empty;
        # reindex so every aggregate is present and order is preserved.
        return frame.reindex(list(data)).infer_objects()
        
