from abc import abstractmethod
from collections import defaultdict
import collections
import sys
from typing import Dict, List, Any, Optional, Set
import os
import json

import yaml
import jinja2
import jsonschema
import jsonschema.exceptions

# Stable issue identifiers. Issues.check() prints each one as the prefix of an
# "ERROR <id>: ..." or "WARNING <id>: ..." line.
# A unit name used by a price entry or a recipe ingredient could not be resolved.
ISSUE_UNKNOWN_UNIT = "unknown-unit"
# A recipe references an ingredient that is not registered.
ISSUE_UNKNOWN_INGREDIENT = "unknown-ingredient"
# The same unit name is registered more than once (reported against units.yaml).
ISSUE_DUPLICATE_UNITS = "duplicate-units"
# Warning: an ingredient has a price, but no conversion path links the requested
# unit to the unit of the price entry.
ISSUE_KNOWN_PRICE_UNKNOWN_CONVERSION = "known-price-unknown-conversion"


class Issue:
    """A single diagnostic: a machine-readable ``id`` plus a human-readable message."""

    def __init__(self, id: str, msg: str) -> None:
        # ``id`` is normally one of the ISSUE_* constants; ``msg`` is free text.
        self.id, self.msg = id, msg


class Issues:
    """Collects errors and warnings, then reports and clears them in one go."""

    def __init__(self) -> None:
        self.errors: List[Issue] = []
        self.warnings: List[Issue] = []

    def error(self, id: str, msg: str) -> None:
        """Record an error; any recorded error makes :meth:`check` return 1."""
        self.errors.append(Issue(id, msg))

    def warn(self, id: str, msg: str) -> None:
        """Record a warning; warnings are printed but never fail :meth:`check`."""
        self.warnings.append(Issue(id, msg))

    def check(self) -> int:
        """Print every collected issue, clear both lists, and return an exit code.

        Errors are printed before warnings, each on its own line.

        Returns:
            1 if at least one error was recorded since the previous call,
            otherwise 0. Always a real ``int`` (usable directly as a process
            exit status), never a ``bool``.
        """
        retcode = 1 if self.errors else 0

        for issue in self.errors:
            print(f"ERROR {issue.id}: {issue.msg}")
        for issue in self.warnings:
            print(f"WARNING {issue.id}: {issue.msg}")

        self.errors.clear()
        self.warnings.clear()
        return retcode


class Context:
    """Shared state for one build: unit and ingredient registries plus issues.

    Until ``load_units`` / ``load_ingredients`` run, lookups go through the
    fake registries, which fabricate an entry for any name requested.
    """

    def __init__(self) -> None:
        self.units: AUnits = FakeUnits(self)
        # Fallback unit for entries without a "unit" key. It is constructed
        # while ``self.units`` is still the fake registry.
        self.default_unit = Unit(self, {"name": "piece"})
        self.ingredients: AIngredients = FakeIngredients(self)
        self.issues = Issues()

    def load_units(
        self, unitsdct: List[Dict[str, Any]], unitsschema: Dict[str, Any]
    ) -> None:
        """Replace the fake unit registry with one built from ``unitsdct``.

        The default unit is registered first, so it can be referenced by name.
        Schema violations raise from ``jsonschema.validate``; duplicate unit
        names are recorded on ``self.issues``.
        """
        registry = Units(self)
        self.units = registry
        registry.units.append(self.default_unit)
        jsonschema.validate(instance=unitsdct, schema=unitsschema)
        registry.load(unitsdct)
        registry.validate()

    def load_ingredients(
        self, ingredientsdct: List[Dict[str, Any]], ingredientsschema: Dict[str, Any]
    ) -> None:
        """Replace the fake ingredient registry with one built from ``ingredientsdct``.

        Schema violations raise from ``jsonschema.validate``.
        """
        registry = Ingredients(self)
        self.ingredients = registry
        jsonschema.validate(instance=ingredientsdct, schema=ingredientsschema)
        registry.load(ingredientsdct)


class Element:
    """Dict-backed model object; subclasses post-process raw data in ``load``.

    ``dct`` is stored by reference, not copied, so ``load`` and item
    assignment mutate the caller's dictionary.
    """

    def __init__(self, ctx: Context, dct: Dict[str, Any]) -> None:
        self.ctx = ctx
        self.dct = dct
        self.load(dct)
        # Any top-level value still being a plain dict after ``load`` is
        # treated as unprocessed data. Dicts nested inside lists are not checked.
        if any(isinstance(value, dict) for value in self.dct.values()):
            raise RuntimeError("Something wasn't processed properly")

    def __contains__(self, item: Any) -> bool:
        return item in self.dct

    def __getitem__(self, key: str) -> Any:
        return self.dct[key]

    def __setitem__(self, key: str, val: Any) -> None:
        self.dct[key] = val

    def __repr__(self) -> str:
        return repr(self.dct)

    @abstractmethod
    def load(self, dct: Dict[str, Any]) -> None:
        """Post-process ``dct`` into this element; implemented by subclasses.

        Element does not derive from ``abc.ABC``, so this decorator is not
        enforced when an instance is created.
        """


class Conversion(Element):
    """A conversion edge between two units.

    ``load`` replaces the "from" and "to" names with the registered Unit
    objects. The "ratio" value is left untouched and is read by
    ``Ingredient.convert``.
    """

    def load(self, dct: Dict[str, Any]) -> None:
        # Resolve "from" first, then "to". A missing unit aborts immediately.
        for endpoint in ("from", "to"):
            resolved = self.ctx.units.get(dct[endpoint])
            if resolved is None:
                raise RuntimeError(f"unit {dct[endpoint]} doesn't exist")
            self[endpoint] = resolved


class Unit(Element):
    """A measurement unit with a "name", optional "aliases" and "conversions".

    ``load`` only normalises aliases. Conversions are resolved later by
    ``finish_load``, once every unit is registered, because each conversion
    looks up its target unit by name.
    """

    def load(self, dct: Dict[str, Any]) -> None:
        # Always store an alias list (a copy of the given one, or empty).
        self["aliases"] = list(dct["aliases"]) if "aliases" in dct else []

    def finish_load(self) -> None:
        """Convert the raw "conversions" entries into Conversion objects.

        The "from" side of every conversion is this unit's name. It is filled in
        here and must not be given in the data.

        Raises:
            RuntimeError: if a conversion entry specifies its own "from", or
                (from ``Conversion``) if a referenced unit does not exist.
        """
        conversions: List[Conversion] = []
        for convdct in self.dct.get("conversions", []):
            # Check each entry, not the surrounding list. The list never
            # contains the string "from", so a list check can never fire.
            if "from" in convdct:
                raise RuntimeError(
                    "conversions in units.yaml cannot have a from field, it is automatically assigned from the unit name"
                )
            convdct["from"] = self["name"]
            conversions.append(Conversion(self.ctx, convdct))
        self["conversions"] = conversions


class AUnits:
    """Base interface shared by the fake and the real unit registries."""

    def __init__(self, ctx: Context) -> None:
        self.ctx = ctx
        # Registered units, in registration order.
        self.units: List[Unit] = []

    def load(self, lst: List[Any]) -> None:
        """Populate the registry from parsed data; does nothing by default."""

    @abstractmethod
    def get(self, name: str) -> Optional[Unit]:
        """Look up a unit by name; subclasses decide what a miss returns."""

    def validate(self) -> None:
        """Report consistency problems on ``ctx.issues``; does nothing by default."""


class FakeUnits(AUnits):
    """Unit registry used before units are loaded: every requested name is
    created on first use, so ``get`` never returns None."""

    def get(self, name: str) -> Optional[Unit]:
        existing = next((u for u in self.units if u["name"] == name), None)
        if existing is not None:
            return existing
        # Construct first and append afterwards: building a Unit rebinds
        # ``self.ctx.units.units``, so ``self.units`` must be looked up after.
        created = Unit(self.ctx, {"name": name})
        self.units.append(created)
        return created


class Units(AUnits):
    """Unit registry built from loaded data; unknown names resolve to None."""

    def load(self, lst: List[Any]) -> None:
        # Register every unit before resolving conversions, since a conversion
        # may target a unit that appears later in ``lst``.
        for raw in lst:
            parsed = Unit(self.ctx, raw)
            self.units.append(parsed)
        for registered in self.units:
            registered.finish_load()

    def get(self, name: str) -> Optional[Unit]:
        """Return the first unit whose name, or one of whose aliases, equals ``name``."""
        for candidate in self.units:
            if candidate["name"] == name:
                return candidate
            if "aliases" in candidate and name in candidate["aliases"]:
                return candidate
        return None

    def validate(self) -> None:
        """Record an error for every unit name registered more than once."""
        counts = collections.Counter(u["name"] for u in self.units)
        for unitname, num in counts.items():
            if num <= 1:
                continue
            self.ctx.issues.error(
                ISSUE_DUPLICATE_UNITS,
                f"units.yaml: {unitname} should only have one entry, found {num}",
            )


class Ingredient(Element):
    """A registered ingredient with optional prices and its own unit conversions.

    Keys set by ``load``:
      * ``"prices"``: a PriceDBs, present only if the data has "prices".
      * ``"conversions"``: a list of Conversion, possibly empty.
    """

    def load(self, dct: Dict[str, Any]) -> None:
        if "prices" in dct:
            pricedb = PriceDBs(self.ctx)
            pricedb.load(dct["prices"])
            self["prices"] = pricedb

        conversions: List[Conversion] = []
        if "conversions" in dct:
            for convdct in dct["conversions"]:
                conversions.append(Conversion(self.ctx, convdct))
        self["conversions"] = conversions

    def dfs(
        self,
        conversions: Dict[str, Dict[str, float]],
        startname: str,
        endname: str,
        visited: Optional[List[str]] = None,
    ) -> Optional[List[str]]:
        """Depth-first search for a chain of unit names from ``startname`` to ``endname``.

        Args:
            conversions: adjacency map ``{from_name: {to_name: ratio}}``.
            startname: unit name to start from.
            endname: unit name to reach.
            visited: names already on the current path (internal; leave None).

        Returns:
            The path including both endpoints (``[startname]`` when they are
            equal), or None if ``endname`` is unreachable. The first path found
            is returned, not necessarily the shortest.
        """
        if visited is None:
            visited = []
        if startname == endname:
            return visited + [startname]

        path_so_far = visited + [startname]
        for nextunit in conversions.get(startname, {}):
            if nextunit in visited:
                continue
            result = self.dfs(conversions, nextunit, endname, path_so_far)
            if result is not None:
                return result
        return None

    def convert(self, amount: float, unitfrom: Unit, unitto: Unit) -> Optional[float]:
        """Convert ``amount`` expressed in ``unitfrom`` into ``unitto``.

        Builds a bidirectional graph (``ratio`` forwards, ``1 / ratio``
        backwards) from this ingredient's conversions plus every registered
        unit's conversions. It then multiplies the ratios along the path found
        by :meth:`dfs`.

        Returns:
            The converted amount. Returns None, after recording an
            ISSUE_KNOWN_PRICE_UNKNOWN_CONVERSION warning, if no path exists.
        """
        graph: Dict[str, Dict[str, float]] = defaultdict(dict)

        # Work on a copy so this ingredient's own conversion list is never
        # extended in place. Otherwise every call would permanently append all
        # unit conversions to it.
        convs: List[Conversion] = list(self["conversions"])
        for unit in self.ctx.units.units:
            # Units from the fake registry never go through finish_load(), so
            # they may have no "conversions" entry at all.
            if "conversions" in unit:
                convs.extend(unit["conversions"])

        for conv in convs:
            fromname = conv["from"]["name"]
            toname = conv["to"]["name"]
            graph[fromname][toname] = conv["ratio"]
            graph[toname][fromname] = 1 / conv["ratio"]

        path = self.dfs(graph, unitfrom["name"], unitto["name"])
        if path is None:
            self.ctx.issues.warn(
                ISSUE_KNOWN_PRICE_UNKNOWN_CONVERSION,
                f'{self["name"]} has a known price, but conversion {unitfrom["name"]} -> {unitto["name"]} not known',
            )
            return None

        for src, dst in zip(path, path[1:]):
            amount *= graph[src][dst]
        return amount

    def getprice(self, amount: float, unit: Unit) -> Optional[float]:
        """Price of ``amount`` of this ingredient measured in ``unit``.

        Each price entry is scaled linearly: ``(amount / entry.amount) *
        entry.price``. When the entry uses a different unit, ``amount`` is
        first converted into it via :meth:`convert`.

        Returns:
            The price, or None if the ingredient has no prices or no entry
            could be matched or converted.
        """
        if "prices" not in self.dct:
            return None
        prices: List[float] = []
        pricedbs: PriceDBs = self["prices"]
        for entry in pricedbs.pricedbs:
            assert isinstance(entry, PriceDB)
            entryamount: float = entry["amount"]
            entryprice: float = entry["price"]
            entryunit: Unit = entry["unit"]
            if entryunit == unit:
                prices.append((amount / entryamount) * entryprice)
            else:
                newamount = self.convert(amount, unit, entryunit)
                if newamount is not None:
                    prices.append((newamount / entryamount) * entryprice)
        if len(prices) == 0:
            return None
        # More than one usable price entry is treated as a data error.
        assert len(prices) == 1
        return prices[0]


class AIngredients:
    """Base interface shared by the fake and the real ingredient registries."""

    def __init__(self, ctx: Context) -> None:
        self.ctx = ctx
        # Registered ingredients, in registration order.
        self.ingredients: List[Ingredient] = []

    def load(self, lst: List[Any]) -> None:
        """Populate the registry from parsed data; does nothing by default."""

    @abstractmethod
    def get(self, name: str) -> Optional[Ingredient]:
        """Look up an ingredient by name; subclasses decide what a miss returns."""


class FakeIngredients(AIngredients):
    """Ingredient registry used before ingredients are loaded: every requested
    name is created on first use, so ``get`` never returns None."""

    def get(self, name: str) -> Optional[Ingredient]:
        match = next((i for i in self.ingredients if i["name"] == name), None)
        if match is not None:
            return match
        created = Ingredient(self.ctx, {"name": name})
        self.ingredients.append(created)
        return created


class Ingredients(AIngredients):
    """Ingredient registry built from loaded data; unknown names resolve to None."""

    def load(self, lst: List[Any]) -> None:
        for raw in lst:
            parsed = Ingredient(self.ctx, raw)
            self.ingredients.append(parsed)

    def get(self, name: str) -> Optional[Ingredient]:
        """Return the first ingredient whose name, or one of whose aliases, equals ``name``."""
        return next(
            (
                candidate
                for candidate in self.ingredients
                if candidate["name"] == name
                or ("aliases" in candidate and name in candidate["aliases"])
            ),
            None,
        )


class PriceDBs:
    """Ordered collection of the PriceDB entries belonging to one ingredient."""

    def __init__(self, ctx: Context) -> None:
        self.ctx = ctx
        self.pricedbs: List[PriceDB] = []

    def load(self, lst: List[Any]) -> None:
        """Append one PriceDB per raw entry in ``lst``, keeping their order."""
        for raw in lst:
            entry = PriceDB(self.ctx, raw)
            self.pricedbs.append(entry)

    def __repr__(self) -> str:
        return repr(self.pricedbs)


class PriceDB(Element):
    """One price entry: "price" for "amount" (default 1.0) of "unit"."""

    def load(self, dct: Dict[str, Any]) -> None:
        if "amount" not in self:
            self["amount"] = 1.0

        # Entries without a unit are priced per default unit.
        if "unit" not in dct:
            self["unit"] = self.ctx.default_unit
            return

        unitstr = dct["unit"]
        resolved = self.ctx.units.get(unitstr)
        self["unit"] = resolved
        if resolved is None:
            self.ctx.issues.error(ISSUE_UNKNOWN_UNIT, f"unknown unit {unitstr}")


class IngredientInstance(Element):
    """An ingredient as used in a recipe.

    Keys set by ``load``: "ingredient" (Ingredient or None), "amount"
    (default 1.0), "unit" (Unit or None), "note" (default ""),
    "alternatives" (list of IngredientInstance from "or") and "price"
    (float or None). Unknown ingredients and units are recorded as errors
    on ``ctx.issues``.
    """

    def load(self, dct: Dict[str, Any]) -> None:
        ingredient = self.ctx.ingredients.get(dct["name"])
        if ingredient is None:
            self.ctx.issues.error(
                ISSUE_UNKNOWN_INGREDIENT, f"unknown ingredient {dct['name']}"
            )
        self["ingredient"] = ingredient

        if "amount" not in dct:
            self["amount"] = 1.0

        if "unit" in dct:
            unitstr = dct["unit"]
            unit = self.ctx.units.get(unitstr)
            if unit is None:
                self.ctx.issues.error(ISSUE_UNKNOWN_UNIT, f"unknown unit {unitstr}")
        else:
            unit = self.ctx.default_unit
        self["unit"] = unit

        if "note" not in dct:
            self["note"] = ""

        alternatives: List[IngredientInstance] = []
        if "or" in dct:
            for ingdct in dct["or"]:
                alternatives.append(IngredientInstance(self.ctx, ingdct))
        self["alternatives"] = alternatives

        # A price can only be computed when both the ingredient and the unit
        # are known; passing an unknown (None) unit on would crash inside
        # Ingredient.convert. The key is always set so that callers such as
        # Recipe can read it unconditionally.
        if ingredient is not None and unit is not None:
            self["price"] = ingredient.getprice(self["amount"], unit)
        else:
            self["price"] = None


class Recipe(Element):
    """A recipe: ingredient instances, nested sub-recipes and a total price.

    ``srcpath`` / ``outpath`` start empty and are filled in by the builder
    after construction.
    """

    def __init__(self, ctx: Context, dct: Dict[str, Any]) -> None:
        super().__init__(ctx, dct)
        # Source file name and rendered output file name.
        self.srcpath = ""
        self.outpath = ""

    def load(self, dct: Dict[str, Any]) -> None:
        ingredients: List[IngredientInstance] = [
            IngredientInstance(self.ctx, ing) for ing in dct.get("ingredients", [])
        ]
        self["ingredients"] = ingredients

        subrecipes: List[Recipe] = [
            Recipe(self.ctx, partdct) for partdct in dct.get("subrecipes", [])
        ]
        self["subrecipes"] = subrecipes

        # The total is only meaningful when every ingredient is priced. An
        # ingredient without a "price" key counts as unpriced. A recipe with
        # no ingredients has no price. Sub-recipe prices are not included.
        known: List[float] = [
            ing["price"]
            for ing in ingredients
            if "price" in ing and ing["price"] is not None
        ]
        price: Optional[float] = None
        if ingredients and len(known) == len(ingredients):
            price = sum(known)
        self["price"] = price


class Builder:
    """Loads a recipe directory and renders it into ``DIR/out/html``.

    Templates are read from ``templates`` and JSON schemas from ``schemas``,
    both relative to the current working directory.
    """

    def __init__(self) -> None:
        self.jinjaenv = jinja2.Environment(
            loader=jinja2.FileSystemLoader("templates"),
            autoescape=jinja2.select_autoescape(),
        )

        def numprint(value: Any) -> str:
            """Stringify a number, dropping a trailing ".0" (2.0 -> "2")."""
            out = str(value)
            if out.endswith(".0"):
                return out.split(".", maxsplit=1)[0]
            return out

        def amountprint(value: Any) -> str:
            """Like ``numprint``, but renders 0.5 / 0.25 / 0.75 as fractions."""
            out = numprint(value)
            return {"0.5": "1/2", "0.25": "1/4", "0.75": "3/4"}.get(out, out)

        self.jinjaenv.filters["numprint"] = numprint
        self.jinjaenv.filters["amountprint"] = amountprint
        self.ctx = Context()
        # Names of output files written during this run (used by finish()).
        self.outfiles: Set[str] = set()

    @staticmethod
    def _list_files(path: str) -> List[str]:
        """Return the sorted names of files directly inside ``path`` ([] if missing)."""
        # The first tuple yielded by os.walk describes ``path`` itself;
        # subdirectories are deliberately not descended into.
        for _, _, filenames in os.walk(path):
            return sorted(filenames)
        return []

    def load_file(self, file: str) -> Any:
        """Read ``file`` as JSON if it ends in ".json", otherwise as YAML."""
        print(f"loading {file}")
        with open(file, encoding="utf-8") as f:
            txt = f.read()
        if file.endswith(".json"):
            return json.loads(txt)
        return yaml.safe_load(txt)

    def rendertemplate(
        self, templatepath: str, format: str, file: str, dir: str, args: Any
    ) -> None:
        """Render ``templatepath`` with ``args`` into ``DIR/out/FORMAT/FILE``."""
        template = self.jinjaenv.get_template(templatepath)
        print(f"rendering {file}")
        outstr = template.render(args)

        os.makedirs(f"{dir}/out/{format}", exist_ok=True)

        with open(f"{dir}/out/{format}/{file}", "w", encoding="utf-8") as f:
            f.write(outstr)
        self.outfiles.add(file)

    def load(self, dir: str) -> int:
        """Load ``DIR/units.yaml`` and ``DIR/ingredients.yaml`` if present.

        Returns 1 if loading either file reported errors, otherwise 0.
        """
        if os.path.isfile(dir + "/units.yaml"):
            unitsschema = self.load_file("schemas/units.json")
            unitsdct = self.load_file(dir + "/units.yaml")
            self.ctx.load_units(unitsdct, unitsschema)
            if self.ctx.issues.check() != 0:
                return 1

        if os.path.isfile(dir + "/ingredients.yaml"):
            ingredientsdct = self.load_file(dir + "/ingredients.yaml")
            ingredientsschema = self.load_file("schemas/ingredients.json")
            self.ctx.load_ingredients(ingredientsdct, ingredientsschema)
            if self.ctx.issues.check() != 0:
                return 1
        return 0

    def run(self, dir: str) -> int:
        """Parse every recipe file in ``DIR/recipes`` and render the HTML pages.

        Recipes that report errors are skipped, so that every problem is shown
        in one pass. If any recipe failed, nothing is rendered and 1 is
        returned.
        """
        # Only files directly inside the recipes directory are recipes.
        files = self._list_files(dir + "/recipes")

        recipes: List[Recipe] = []
        failed = False
        recipeschema = self.load_file("schemas/recipe.json")
        for file in files:
            if not file.endswith(".yaml"):
                print(f"unknown extension of {file}")
                continue
            recipedct = self.load_file(dir + "/recipes/" + file)
            jsonschema.validate(instance=recipedct, schema=recipeschema)
            recipe = Recipe(self.ctx, recipedct)
            recipe.srcpath = file
            recipe.outpath = file[:-5] + ".html"
            # check() prints and clears this recipe's issues, so remember
            # the failure instead of relying on a later check.
            if self.ctx.issues.check() != 0:
                failed = True
                continue
            recipes.append(recipe)

        if self.ctx.issues.check() != 0 or failed:
            return 1

        self.rendertemplate(
            templatepath="index.html",
            format="html",
            file="index.html",
            dir=dir,
            args={"recipes": recipes},
        )
        for recipe in recipes:
            self.rendertemplate(
                templatepath="recipe.html",
                format="html",
                file=recipe.outpath,
                dir=dir,
                args={"recipe": recipe},
            )
        return 0

    def finish(self, dir: str) -> int:
        """Delete files in ``DIR/out/html`` that this run did not generate."""
        outdir = f"{dir}/out/html"
        # Only top-level files are considered: rendertemplate writes nothing
        # deeper, and the removal path below only addresses the top level.
        files = set(self._list_files(outdir))

        # Files we did not generate, probably left by a previous run and no
        # longer valid.
        extra_files = files - self.outfiles
        for file in sorted(extra_files):
            print(f"removing obsolete {file}")
            os.remove(f"{outdir}/{file}")
        return 0

    def build(self, path: str) -> int:
        """Run load, run and finish in order, stopping at the first non-zero code.

        Schema validation failures are printed and turned into exit code 1.
        """
        for step in (self.load, self.run, self.finish):
            try:
                ret = step(path)
            except jsonschema.exceptions.ValidationError as e:
                print("ERROR:", e)
                return 1
            if ret != 0:
                return ret
        return 0


def help() -> None:
    """Print command-line usage, using the invoked program name."""
    prog = sys.argv[0]
    print(f"usage: {prog} build DIR - build pages in DIR/out")
    print(f"       {prog} -h        - show help")


def main() -> None:
    """Command-line entry point: ``build DIR`` or ``-h``; anything else prints usage and exits 1."""
    args = sys.argv[1:]
    if args == ["-h"]:
        help()
        sys.exit(0)
    if len(args) == 2 and args[0] == "build":
        sys.exit(Builder().build(args[1]))
    help()
    sys.exit(1)


if __name__ == "__main__":
    main()