import { Expression } from "@hypertune/sdk/src/shared/types";
import {
  getEmptyLogs,
  mapExpressionWithResult,
} from "@hypertune/sdk/src/shared/helpers";
import getApplicationFunctionExpression from "../expression/getApplicationFunctionExpression";
import {
  ExpressionMap,
  ExpressionMapPointer,
  ExpressionMapValue,
} from "./types";

/**
 * Flattens an expression tree into an {@link ExpressionMap}: a record keyed by
 * expression id whose values are each expression's "map value" form, with
 * child expressions replaced by {@link ExpressionMapPointer}s.
 *
 * ApplicationExpressions get no entry of their own (toExpressionMapValue
 * returns null for them); pointers to them are redirected to the applied
 * function's body by newExpressionMapPointer.
 *
 * @param startExpression - root of the tree to flatten; `null` yields `{}`.
 * @returns a record from expression id to its map value.
 */
export default function getExpressionMap(
  startExpression: Expression | null
): ExpressionMap {
  if (!startExpression) {
    return {};
  }
  return mapExpressionWithResult<Record<string, ExpressionMapValue>>(
    (expr) => {
      const mapValueExpr = expr ? toExpressionMapValue(expr) : null;
      if (!expr || !mapValueExpr) {
        // Null slots and expressions without a map value contribute nothing.
        return { newExpression: expr, mapResult: {} };
      }
      return {
        newExpression: expr,
        mapResult: {
          [expr.id]: {
            ...mapValueExpr,
            // Override these fields to ensure
            // they don't cause spurious differences.
            isTransient: false,
            logs: getEmptyLogs(),
          },
        },
      };
    },
    (...results) => {
      // Merge every partial result into one accumulator. Re-spreading the
      // accumulated record on each step would copy it once per result,
      // which is quadratic in the number of entries. Later results win on
      // key collisions, as before.
      const merged: Record<string, ExpressionMapValue> = {};
      for (const result of results) {
        Object.assign(merged, result);
      }
      return merged;
    },
    startExpression
  ).mapResult;
}

/**
 * Converts a single expression into its ExpressionMapValue form.
 *
 * Leaf expressions (literals, enums, variables, no-ops) are returned as-is.
 * Composite expressions keep their own fields via spread, get their `type`
 * renamed to the corresponding `...MapValue` variant, and have every child
 * expression replaced by an ExpressionMapPointer. Ordered collections (list
 * items, switch cases) are stored as id-keyed records plus a parallel
 * `...Weights` record holding each entry's original index.
 *
 * @param expression - the expression to convert.
 * @returns the map value, or `null` for ApplicationExpression (which has no
 *   entry of its own in the map).
 * @throws Error if `expression.type` is not a known expression type.
 */
function toExpressionMapValue(
  expression: Expression
): ExpressionMapValue | null {
  switch (expression.type) {
    case "NoOpExpression":
    case "BooleanExpression":
    case "StringExpression":
    case "IntExpression":
    case "FloatExpression":
    case "RegexExpression":
    case "EnumExpression":
    case "VariableExpression":
      return expression;

    case "ObjectExpression":
      return {
        ...expression,
        type: "ObjectExpressionMapValue",
        fields: newRecordExpressionMapPointers(expression.fields),
      };
    case "GetFieldExpression":
      return {
        ...expression,
        type: "GetFieldExpressionMapValue",
        object: newExpressionMapPointer(expression.object),
      };
    case "UpdateObjectExpression":
      return {
        ...expression,
        type: "UpdateObjectExpressionMapValue",
        object: newExpressionMapPointer(expression.object),
        updates: newRecordExpressionMapPointers(expression.updates),
      };
    case "ListExpression":
      // NOTE(review): the `!` assertions assume no list item is null; a null
      // item would throw a TypeError here. Confirm lists cannot hold nulls.
      return {
        ...expression,
        type: "ListExpressionMapValue",
        items: Object.fromEntries(
          expression.items.map((itemExpr) => [
            itemExpr!.id,
            newExpressionMapPointer(itemExpr),
          ])
        ),
        itemsWeights: Object.fromEntries(
          expression.items.map((itemExpr, index) => [itemExpr!.id, index])
        ),
      };
    case "SwitchExpression":
      return {
        ...expression,
        type: "SwitchExpressionMapValue",
        control: newExpressionMapPointer(expression.control),
        default: newExpressionMapPointer(expression.default),
        cases: Object.fromEntries(
          expression.cases.map(({ id, when, then }) => [
            id,
            {
              when: newExpressionMapPointer(when),
              then: newExpressionMapPointer(then),
            },
          ])
        ),
        casesWeights: Object.fromEntries(
          expression.cases.map(({ id }, index) => [id, index])
        ),
      };
    case "EnumSwitchExpression":
      return {
        ...expression,
        type: "EnumSwitchExpressionMapValue",
        control: newExpressionMapPointer(expression.control),
        cases: newRecordExpressionMapPointers(expression.cases),
      };
    case "ArithmeticExpression":
      return {
        ...expression,
        type: "ArithmeticExpressionMapValue",
        a: newExpressionMapPointer(expression.a),
        b: newExpressionMapPointer(expression.b),
      };
    case "ComparisonExpression":
      return {
        ...expression,
        type: "ComparisonExpressionMapValue",
        a: newExpressionMapPointer(expression.a),
        b: newExpressionMapPointer(expression.b),
      };
    case "RoundNumberExpression":
      return {
        ...expression,
        type: "RoundNumberExpressionMapValue",
        number: newExpressionMapPointer(expression.number),
      };
    case "StringifyNumberExpression":
      return {
        ...expression,
        type: "StringifyNumberExpressionMapValue",
        number: newExpressionMapPointer(expression.number),
      };
    case "StringConcatExpression":
      return {
        ...expression,
        type: "StringConcatExpressionMapValue",
        strings: newExpressionMapPointer(expression.strings),
      };
    case "GetUrlQueryParameterExpression":
      return {
        ...expression,
        type: "GetUrlQueryParameterExpressionMapValue",
        url: newExpressionMapPointer(expression.url),
        queryParameterName: newExpressionMapPointer(
          expression.queryParameterName
        ),
      };
    case "SplitExpression":
      return {
        ...expression,
        type: "SplitExpressionMapValue",
        expose: newExpressionMapPointer(expression.expose),
        unitId: newExpressionMapPointer(expression.unitId),
        dimensionMapping:
          expression.dimensionMapping.type === "discrete"
            ? {
                ...expression.dimensionMapping,
                type: "ExpressionMapValueDiscreteDimensionMapping",
                cases: newRecordExpressionMapPointers(
                  expression.dimensionMapping.cases
                ),
              }
            : {
                ...expression.dimensionMapping,
                type: "ExpressionMapValueContinuousDimensionMapping",
                function: newExpressionMapPointer(
                  expression.dimensionMapping.function
                ),
              },
        eventPayload: newExpressionMapPointer(expression.eventPayload),
        featuresMapping: newRecordExpressionMapPointers(
          expression.featuresMapping
        ),
      };
    case "LogEventExpression":
      return {
        ...expression,
        type: "LogEventExpressionMapValue",
        unitId: newExpressionMapPointer(expression.unitId),
        eventPayload: newExpressionMapPointer(expression.eventPayload),
        featuresMapping: newRecordExpressionMapPointers(
          expression.featuresMapping
        ),
      };
    case "FunctionExpression":
      return {
        ...expression,
        type: "FunctionExpressionMapValue",
        body: newExpressionMapPointer(expression.body),
      };
    case "ApplicationExpression":
      // Applications have no entry of their own; pointers to them are
      // redirected to the applied function's body by newExpressionMapPointer.
      return null;
    default: {
      // Compile-time exhaustiveness check: adding an Expression variant
      // without handling it above makes this assignment fail to type-check.
      const neverExpression: never = expression;
      // Interpolating the object itself would print "[object Object]";
      // report the unrecognized discriminant so the error is actionable.
      throw new Error(
        `unexpected expression: ${String(
          (neverExpression as { type?: unknown }).type
        )}`
      );
    }
  }
}

/**
 * Replaces every expression in a string-keyed record with a pointer to it,
 * keeping the same keys.
 *
 * @param exprRecord - record of (possibly null) child expressions.
 * @returns a record with the same keys whose values are the pointers built
 *   by newExpressionMapPointer (null children become null pointers).
 */
function newRecordExpressionMapPointers(
  exprRecord: Record<string, Expression | null>
): Record<string, ExpressionMapPointer> {
  const pointerEntries = Object.entries(exprRecord).map(
    ([key, childExpression]) =>
      [key, newExpressionMapPointer(childExpression)] as const
  );
  return Object.fromEntries(pointerEntries);
}

/**
 * Builds a pointer to `expression` for use inside an ExpressionMapValue.
 *
 * - `null` maps to a `null` pointer.
 * - A non-application expression maps to `{ type: "ExpressionMapPointer", id }`.
 * - An ApplicationExpression has no map entry of its own, so the pointer
 *   targets the body of the applied function instead, and records the
 *   argument bound to each parameter in `variablesExpressionData`, merged
 *   with any bindings carried by the body's own pointer (nested
 *   applications). If the applied function cannot be resolved, the pointer
 *   is built from the application's `function` slot instead.
 *
 * @param expression - the expression to point at, or null.
 * @returns the pointer, or null (see the ApplicationExpression notes below).
 */
function newExpressionMapPointer(
  expression: Expression | null
): ExpressionMapPointer {
  if (!expression) {
    return null;
  }
  if (expression.type !== "ApplicationExpression") {
    return { type: "ExpressionMapPointer", id: expression.id };
  }

  // Presumably resolves the function being applied; returns a falsy value
  // when it cannot — TODO confirm against getApplicationFunctionExpression.
  const functionExpression = getApplicationFunctionExpression(expression);
  if (!functionExpression) {
    return newExpressionMapPointer(expression.function);
  }
  const result = newExpressionMapPointer(functionExpression.body);

  // NOTE(review): applying a zero-parameter function yields a null pointer
  // rather than `result`; behavior preserved — confirm this is intended.
  if (!result || functionExpression.parameters.length === 0) {
    return null;
  }

  const parameterCount = functionExpression.parameters.length;

  // This application's parameters take weights 0..parameterCount-1.
  // Typed tuples (`as const`) keep Object.fromEntries off its `any` overload.
  const parameterEntries = functionExpression.parameters.map(
    (param, index) =>
      [
        param.id,
        {
          id: param.id,
          name: param.name,
          // A missing argument maps to a null pointer.
          arg: newExpressionMapPointer(expression.arguments[index] ?? null),
          weight: index,
        },
      ] as const
  );

  // Bindings from nested applications inside the body are ordered after this
  // level's parameters. The offset is the parameter count (the range used
  // above), so the two ranges cannot overlap even if the argument count
  // differs from it.
  const nestedEntries = result.variablesExpressionData
    ? Object.entries(result.variablesExpressionData).map(
        ([variableId, data]) =>
          [
            variableId,
            { ...data, weight: data.weight + parameterCount },
          ] as const
      )
    : [];

  return {
    type: "ExpressionMapPointer",
    id: result.id,
    // Later entries win on id collisions, so nested bindings override
    // same-id parameters (same precedence as before).
    variablesExpressionData: Object.fromEntries([
      ...parameterEntries,
      ...nestedEntries,
    ]),
  };
}
