// FPGraphPage.js
import React, { useMemo, useRef, useContext } from "react";
import { Canvas } from "@react-three/fiber";
import { OrbitControls, Html } from "@react-three/drei";
import * as THREE from "three";
import { createNoise3D } from "simplex-noise";
import { useDataset } from "../components/DataFetcher";
import * as d3 from "d3-force";
import { DimensionContext } from '../components/ResponsiveWrapper';

// Colour palette (hex strings). Terrain turns these into THREE.Color values for
// its elevation bands. The dk/md/lt prefixes mark dark/medium/light variants of
// the same hue (compare the hex values).
const colorPalette = {
  dkpurple: "#2c294b",
  mdpurple: "#3b3484",
  ltpurple: "#7174b0",
  dkorange: "#de6736",
  mdorange: "#e08a3c",
  ltorange: "#ebb844",
  cranberry: "#762861",
  magenta: "#c5316a",
  mdgreen: "#80ba55",
};

/**
 * Flatten a raw antecedents/consequents cell into a flat list of items.
 *
 * @param {*} itemList - Cell value; may be a (nested) array, null, or anything else.
 * @returns {Array} Fully flattened copy of the array, or [] for any non-array input.
 */
function parseItemList(itemList) {
  const isList = Array.isArray(itemList);
  return isList ? itemList.flat(Infinity) : [];
}

/**
 * Fill a width×depth heightmap with noise sampled from `simplex`.
 *
 * Cells are stored row-major (index = row * width + col). Each cell samples
 * simplex(col / width * scale, row / depth * scale, 0) and multiplies by
 * `baseAmplitude`.
 *
 * @param {number} width - Columns per row.
 * @param {number} depth - Number of rows.
 * @param {(x: number, y: number, z: number) => number} simplex - 3D noise function.
 * @param {number} [scale=8] - Frequency multiplier applied to the 0–1 cell coordinates.
 * @param {number} [baseAmplitude=0.5] - Output multiplier.
 * @returns {Float32Array} The new heightmap.
 */
function generateBaseNoise(width, depth, simplex, scale = 8, baseAmplitude = 0.5) {
  const field = new Float32Array(width * depth);
  let cell = 0;
  for (let row = 0; row < depth; row += 1) {
    const sampleZ = (row / depth) * scale;
    for (let col = 0; col < width; col += 1) {
      const sampleX = (col / width) * scale;
      field[cell] = simplex(sampleX, sampleZ, 0) * baseAmplitude;
      cell += 1;
    }
  }
  return field;
}

/**
 * Add one Gaussian peak per item to the heightmap.
 *
 * A peak's height is the sum of the item's support, lift, conviction and
 * Zhang's metric, divided by `maxPossibleAmplitude` and multiplied by
 * `amplitudeScale * 40`. Peaks are added on top of the existing heights.
 *
 * Items are skipped when their summed metric is not a finite positive number,
 * or when their grid position is not finite. Either case would otherwise fill
 * the whole grid with NaN/Infinity.
 *
 * @param {number} width - Number of grid columns (x).
 * @param {number} depth - Number of grid rows (z).
 * @param {Float32Array} heightmap - Row-major grid (index = z * width + x); mutated in place.
 * @param {Array<{item: *, position: number[], metrics: Object}>} itemPositions - Items to raise.
 * @param {(position: number[], gridSize: number) => number[]} normalizePositionFn -
 *   Maps a layout position to grid indices [ix, iz]; called with `width` as the grid size.
 * @param {number} [sigmaPeak=0.08] - Gaussian standard deviation, as a fraction of the grid.
 * @param {number} [amplitudeScale=8.0] - Height multiplier.
 * @param {number} [maxPossibleAmplitude=80000] - Tuning divisor applied to the summed metrics.
 * @returns {Float32Array} The same `heightmap`.
 */
function applyPeaks(
  width,
  depth,
  heightmap,
  itemPositions,
  normalizePositionFn,
  sigmaPeak = 0.08,
  amplitudeScale = 8.0,
  maxPossibleAmplitude = 80000
) {
  const cellCount = width * depth;
  const twoSigmaSq = 2 * sigmaPeak * sigmaPeak;

  for (const { position, metrics } of itemPositions) {
    const rawAmplitude =
      (metrics.support || 0) +
      (metrics.lift || 0) +
      (metrics.conviction || 0) +
      (metrics.zhangs_metric || 0);

    // Rejects <= 0, NaN and ±Infinity (Infinity × an underflowed exp() is NaN).
    if (!Number.isFinite(rawAmplitude) || rawAmplitude <= 0) continue;

    const [ix, iz] = normalizePositionFn(position, width);
    // A NaN centre (e.g. from a zero-width extent) would make every cell NaN.
    if (!Number.isFinite(ix) || !Number.isFinite(iz)) continue;

    const scaledAmplitude = (rawAmplitude / maxPossibleAmplitude) * amplitudeScale * 40;
    const centreX = ix / width;
    const centreZ = iz / depth;

    for (let i = 0; i < cellCount; i++) {
      const dx = (i % width) / width - centreX;
      const dz = Math.floor(i / width) / depth - centreZ;
      const distSq = dx * dx + dz * dz;
      // Influence is always >= 0 here, so peaks only ever add height.
      heightmap[i] += Math.exp(-distSq / twoSigmaSq) * scaledAmplitude;
    }
  }
  return heightmap;
}

/**
 * Euclidean distance between (x1, z1) and (x2, z2) in the x/z plane.
 */
function planarDistance(x1, z1, x2, z2) {
  const ddx = x1 - x2;
  const ddz = z1 - z2;
  return Math.sqrt(ddx * ddx + ddz * ddz);
}

/**
 * Distance from point P = (px, pz) to the segment A→C, plus the projection
 * parameter t clamped to [0, 1] (0 at A, 1 at C).
 *
 * If A and C coincide, the distance to A is returned with t = 0.
 *
 * @returns {{dist: number, t: number}}
 */
function distanceAndProjectionToLineSegment(px, pz, ax, az, cx, cz) {
  const segX = cx - ax;
  const segZ = cz - az;
  const segLenSq = segX * segX + segZ * segZ;

  if (segLenSq === 0) {
    return { dist: planarDistance(px, pz, ax, az), t: 0 };
  }

  const t = ((px - ax) * segX + (pz - az) * segZ) / segLenSq;
  if (t < 0) {
    // Projection falls before A: the nearest point is A itself.
    return { dist: planarDistance(px, pz, ax, az), t: 0 };
  }
  if (t > 1) {
    // Projection falls past C: the nearest point is C itself.
    return { dist: planarDistance(px, pz, cx, cz), t: 1 };
  }
  return { dist: planarDistance(px, pz, ax + t * segX, az + t * segZ), t };
}

/**
 * Raise ridges along the segments that join each rule's antecedent items to
 * its first consequent item.
 *
 * Pass 1 builds a ridge map. For each cell it keeps the strongest Gaussian
 * falloff (by distance to any segment), scaled by the rule's support × ridgeScale.
 *
 * Pass 2 sets a cell to baseline + ridge when two conditions hold: that value
 * is higher than the cell's current height, and it stays at or below
 * baseline + ridgeScale.
 *
 * @param {number} width - Number of grid columns (x).
 * @param {number} depth - Number of grid rows (z).
 * @param {Float32Array} heightmap - Row-major grid (index = z * width + x); mutated in place.
 * @param {Array<{antecedents: Array, consequents: Array, support: number}>} parsedData - Rules.
 * @param {Array<{item: *, position: number[]}>} itemPositions - Layout positions per item.
 * @param {(position: number[], gridSize: number) => number[]} normalizePositionFn -
 *   Maps a layout position to grid indices [ix, iz]; called with `width` as the grid size.
 * @param {number} sigmaRidge - Gaussian standard deviation across the ridge, as a fraction of the grid.
 * @param {number} ridgeScale - Ridge height multiplier and maximum ridge rise above baseline.
 * @param {ArrayLike<number>} baseline - Per-cell heights that ridges are measured from.
 * @param {number[]} adjustedXExtent - Unused; kept for call-site compatibility.
 * @param {number[]} adjustedYExtent - Unused; kept for call-site compatibility.
 * @returns {Float32Array} The same `heightmap`.
 */
function applyLineRidges(
  width,
  depth,
  heightmap,
  parsedData,
  itemPositions,
  normalizePositionFn,
  sigmaRidge,
  ridgeScale,
  baseline,
  adjustedXExtent,
  adjustedYExtent
) {
  const cellCount = width * depth;
  const ridgeMap = new Float32Array(cellCount);
  const twoSigmaSq = 2 * sigmaRidge * sigmaRidge;

  // Index items once instead of scanning `itemPositions` for every rule and
  // antecedent. The first occurrence wins, matching Array.prototype.find.
  const entryByItem = new Map();
  for (const entry of itemPositions) {
    if (!entryByItem.has(entry.item)) entryByItem.set(entry.item, entry);
  }

  for (const rule of parsedData) {
    if (!rule.consequents || rule.consequents.length === 0) continue;
    if (!rule.antecedents || rule.antecedents.length === 0) continue;

    const consequentEntry = entryByItem.get(rule.consequents[0]);
    if (!consequentEntry) continue;

    const cPos = normalizePositionFn(consequentEntry.position, width);
    const cx = cPos[0] / width;
    const cz = cPos[1] / depth;

    // Numerically this is support * ridgeScale; the original expression form is
    // kept so resulting heights are bit-for-bit unchanged.
    const supportInfluence = ((rule.support || 0) / 100) * ridgeScale * 100;

    for (const antItem of rule.antecedents) {
      const antecedentEntry = entryByItem.get(antItem);
      if (!antecedentEntry) continue;

      const aPos = normalizePositionFn(antecedentEntry.position, width);
      const ax = aPos[0] / width;
      const az = aPos[1] / depth;

      for (let i = 0; i < cellCount; i++) {
        const x = (i % width) / width;
        const z = Math.floor(i / width) / depth;
        const { dist } = distanceAndProjectionToLineSegment(x, z, ax, az, cx, cz);
        const influence = Math.exp(-(dist * dist) / twoSigmaSq) * supportInfluence;
        // Overlapping ridges keep the strongest contribution rather than summing.
        if (influence > ridgeMap[i]) ridgeMap[i] = influence;
      }
    }
  }

  for (let i = 0; i < cellCount; i++) {
    const ridgeVal = ridgeMap[i];
    if (!(ridgeVal > 0)) continue;

    const baseVal = baseline[i];
    const ridgeHeight = baseVal + ridgeVal;
    const peakHeight = baseVal + ridgeScale;
    // NOTE(review): a ridge taller than ridgeScale is skipped entirely rather
    // than capped at peakHeight — confirm this is intended.
    if (ridgeHeight <= peakHeight && heightmap[i] < ridgeHeight) {
      heightmap[i] = ridgeHeight;
    }
  }
  return heightmap;
}


/**
 * [min, max] of `values`, widened on each side by `bufferFactor` × span.
 *
 * Degenerate inputs never produce a zero-width or infinite extent, so callers
 * can safely divide by it:
 * - A zero span (one node, or all nodes aligned on this axis) is widened by `fallbackPad`.
 * - An empty list yields [-fallbackPad, fallbackPad].
 */
function paddedExtent(values, bufferFactor, fallbackPad) {
  if (values.length === 0) return [-fallbackPad, fallbackPad];
  let min = Infinity;
  let max = -Infinity;
  for (const v of values) {
    if (v < min) min = v;
    if (v > max) max = v;
  }
  const span = max - min;
  const pad = span > 0 ? span * bufferFactor : fallbackPad;
  return [min - pad, max + pad];
}

/**
 * Lay out items on a 2D plane with a d3 force simulation.
 *
 * Every antecedent is linked to its rule's first consequent. Nodes repel each
 * other and are kept apart by a collision radius.
 *
 * @param {Array<*>} uniqueItems - Item identifiers; each becomes one node.
 * @param {Array<{antecedents: Array, consequents: Array, support: number}>} parsedData - Parsed rules.
 * @returns {{itemPositions: Array<{item: *, position: [number, number]}>,
 *            adjustedXExtent: [number, number], adjustedYExtent: [number, number]}}
 *   Node positions plus padded [min, max] extents for each axis.
 */
function arrangeItems(uniqueItems, parsedData) {
  const COLLISION_RADIUS = 15;
  const TICKS = 300;
  const BUFFER_FACTOR = 0.15; // pad each axis by 15% of its span

  const nodes = uniqueItems.map((item) => ({ id: item }));
  // A rule without a consequent would create a link to `undefined`, which
  // d3.forceLink rejects with "node not found".
  const links = parsedData
    .filter((rule) => rule.consequents?.length > 0)
    .flatMap((rule) =>
      rule.antecedents.map((antecedent) => ({
        source: antecedent,
        target: rule.consequents[0],
        value: rule.support,
      }))
    );

  // Stopped immediately because we tick manually below. Otherwise d3's
  // internal timer keeps running the simulation in the background.
  const simulation = d3
    .forceSimulation(nodes)
    .force("link", d3.forceLink(links).id((d) => d.id).strength(0.8))
    .force("charge", d3.forceManyBody().strength(-100))
    .force("center", d3.forceCenter(0, 0))
    .force("collision", d3.forceCollide(COLLISION_RADIUS))
    .stop();

  for (let i = 0; i < TICKS; i++) simulation.tick();

  const adjustedXExtent = paddedExtent(nodes.map((node) => node.x), BUFFER_FACTOR, COLLISION_RADIUS);
  const adjustedYExtent = paddedExtent(nodes.map((node) => node.y), BUFFER_FACTOR, COLLISION_RADIUS);

  return {
    itemPositions: nodes.map((node) => ({
      item: node.id,
      position: [node.x, node.y],
    })),
    adjustedXExtent,
    adjustedYExtent,
  };
}

/**
 * Accumulate rule metrics onto items, then rescale each metric across items.
 *
 * Step 1: every rule's support, lift, conviction and zhangs_metric is split
 * evenly across all items in its antecedents and consequents. Items missing
 * from `itemMetrics` are ignored.
 *
 * Step 2: each value becomes max(value × TARGET_MAX / max(metricMax, 1),
 * value × MINIMUM_SCALE), where metricMax is that metric's largest value
 * across items. For non-negative values:
 * - the largest item reaches TARGET_MAX (10000) while metricMax ≤ 200;
 * - above that, the ×50 floor wins, so values can exceed TARGET_MAX.
 *
 * @param {Array<{antecedents: Array, consequents: Array}>} parsedData - Rules with numeric metric fields.
 * @param {Object<string, {support: number, lift: number, conviction: number, zhangs_metric: number}>} itemMetrics -
 *   Per-item accumulators; mutated in place.
 */
function distributeMetrics(parsedData, itemMetrics) {
  const metricKeys = ["support", "lift", "conviction", "zhangs_metric"];
  const TARGET_MAX = 10000;
  const MINIMUM_SCALE = 50;

  for (const rule of parsedData) {
    const allItems = [...rule.antecedents, ...rule.consequents];
    if (allItems.length === 0) continue;
    const weight = 1 / allItems.length;

    for (const item of allItems) {
      const metrics = itemMetrics[item];
      if (!metrics) continue;
      for (const key of metricKeys) {
        metrics[key] += (rule[key] || 0) * weight;
      }
    }
  }

  const maxValues = Object.fromEntries(metricKeys.map((key) => [key, 0]));
  for (const metrics of Object.values(itemMetrics)) {
    for (const key of metricKeys) {
      maxValues[key] = Math.max(maxValues[key], metrics[key]);
    }
  }

  for (const metrics of Object.values(itemMetrics)) {
    for (const key of metricKeys) {
      const value = metrics[key];
      // Dividing by at least 1 means small maxima scale up by up to TARGET_MAX.
      metrics[key] = Math.max(
        value * (TARGET_MAX / Math.max(maxValues[key], 1)),
        value * MINIMUM_SCALE
      );
    }
  }
}

/**
 * Map a layout position to integer grid indices.
 *
 * Each axis is scaled from its [min, max] extent onto 0..gridSize-1 and rounded
 * to the nearest index. If an extent has zero width (or is not finite), that
 * axis maps to the grid centre, because the division would otherwise produce NaN.
 *
 * @param {[number, number]} position - Layout [x, y].
 * @param {number} gridSize - Number of cells along each axis.
 * @param {[number, number]} adjustedXExtent - [min, max] for x.
 * @param {[number, number]} adjustedYExtent - [min, max] for y.
 * @returns {[number, number]} Grid indices [ix, iy]; not clamped.
 */
function normalizePosition(position, gridSize, adjustedXExtent, adjustedYExtent) {
  const [x, y] = position;

  const toIndex = (value, [min, max]) => {
    const span = max - min;
    if (span === 0 || !Number.isFinite(span)) {
      return Math.round((gridSize - 1) / 2);
    }
    return Math.round(((value - min) / span) * (gridSize - 1));
  };

  return [toIndex(x, adjustedXExtent), toIndex(y, adjustedYExtent)];
}

/**
 * Linearly map `coord` from [coordMin, coordMax] onto the terrain's world span
 * [-terrainSize / 2, terrainSize / 2].
 *
 * A zero-width (or non-finite) source range maps to 0, the terrain centre,
 * instead of NaN.
 *
 * @param {number} coord - Layout coordinate.
 * @param {number} coordMin - Start of the layout extent.
 * @param {number} coordMax - End of the layout extent.
 * @param {number} terrainSize - Side length of the terrain in world units.
 * @returns {number} World-space coordinate.
 */
function mapToTerrainCoords(coord, coordMin, coordMax, terrainSize) {
  const span = coordMax - coordMin;
  if (span === 0 || !Number.isFinite(span)) return 0;
  return ((coord - coordMin) / span) * terrainSize - terrainSize / 2;
}

/**
 * Blur the heightmap's interior with a weighted 3×3 average.
 *
 * The centre cell has weight 3 and each of its 8 neighbours weight 1. Each
 * pass reads only the previous pass's values (a Jacobi step). Border cells are
 * never changed.
 *
 * @param {Float32Array} heightmap - Row-major grid (index = z * width + x); overwritten in place.
 * @param {number} width - Number of columns.
 * @param {number} depth - Number of rows.
 * @param {number} [iterations=2] - Number of smoothing passes.
 * @returns {Float32Array} The same `heightmap`.
 */
function smoothHeightmap(heightmap, width, depth, iterations = 2) {
  const smoothed = new Float32Array(heightmap);
  const WEIGHT_TOTAL = 3 + 8;

  for (let iter = 0; iter < iterations; iter++) {
    for (let z = 1; z < depth - 1; z++) {
      const rowStart = z * width;
      for (let x = 1; x < width - 1; x++) {
        const i = rowStart + x;
        const up = i - width;
        const down = i + width;
        // Summed directly (same order as before) instead of allocating a
        // neighbour array per cell, which dominated this hot loop.
        const neighbourSum =
          heightmap[i - 1] +
          heightmap[i + 1] +
          heightmap[up] +
          heightmap[down] +
          heightmap[up - 1] +
          heightmap[up + 1] +
          heightmap[down - 1] +
          heightmap[down + 1];
        smoothed[i] = (heightmap[i] * 3 + neighbourSum) / WEIGHT_TOTAL;
      }
    }
    heightmap.set(smoothed);
  }
  return heightmap;
}


/**
 * Build the full terrain heightmap for a width×depth grid. Steps:
 *   1. simplex base noise;
 *   2. one Gaussian peak per item (applyPeaks);
 *   3. rescale so the tallest point is at MAX_PEAK_HEIGHT, when anything is above 0;
 *   4. ridges along antecedent→consequent segments, measured from the rescaled map;
 *   5. SMOOTHING_PASSES smoothing passes.
 *
 * createNoise3D() is called with no random source, so the noise presumably
 * differs on every call.
 *
 * @param {number} width - Number of grid columns.
 * @param {number} depth - Number of grid rows.
 * @param {Array<{item: *, position: number[], metrics: Object}>} itemPositions - Laid-out items.
 * @param {Array<Object>} parsedData - Parsed rules.
 * @param {Object} itemMetrics - Not read here; kept for call-site compatibility
 *   (applyPeaks reads metrics from `itemPositions[*].metrics`).
 * @param {[number, number]} adjustedXExtent - Layout [min, max] for x.
 * @param {[number, number]} adjustedYExtent - Layout [min, max] for y.
 * @returns {Float32Array} Row-major heightmap (index = z * width + x).
 */
function generateHeightmapWithPeaksAndRidges(
  width,
  depth,
  itemPositions,
  parsedData,
  itemMetrics,
  adjustedXExtent,
  adjustedYExtent
) {
  const NOISE_SCALE = 5;
  const NOISE_AMPLITUDE = 0.25;
  const PEAK_SIGMA = 0.05;
  const PEAK_AMPLITUDE_SCALE = 4.0;
  const MAX_PEAK_HEIGHT = 80;
  const RIDGE_SIGMA = 0.007;
  const RIDGE_SCALE = 80;
  const SMOOTHING_PASSES = 8;

  const toGrid = (pos, gridSize) =>
    normalizePosition(pos, gridSize, adjustedXExtent, adjustedYExtent);

  const simplex = createNoise3D();
  let heightmap = generateBaseNoise(width, depth, simplex, NOISE_SCALE, NOISE_AMPLITUDE);

  heightmap = applyPeaks(
    width,
    depth,
    heightmap,
    itemPositions,
    toGrid,
    PEAK_SIGMA,
    PEAK_AMPLITUDE_SCALE
  );

  // reduce() instead of Math.max(...heightmap): spreading a large grid as call
  // arguments throws a RangeError. Math.max still propagates NaN, as before.
  const maxHeight = heightmap.reduce((max, value) => Math.max(max, value), -Infinity);
  if (maxHeight > 0) {
    for (let i = 0; i < heightmap.length; i++) {
      heightmap[i] = (heightmap[i] / maxHeight) * MAX_PEAK_HEIGHT;
    }
  }

  // Ridges are measured from the post-peak surface, not from a map they modify.
  const postPeaksHeightmap = heightmap.slice();

  heightmap = applyLineRidges(
    width,
    depth,
    heightmap,
    parsedData,
    itemPositions,
    toGrid,
    RIDGE_SIGMA,
    RIDGE_SCALE,
    postPeaksHeightmap,
    adjustedXExtent,
    adjustedYExtent
  );

  return smoothHeightmap(heightmap, width, depth, SMOOTHING_PASSES);
}

function Terrain({ itemPositions, parsedData, itemMetrics, adjustedXExtent, adjustedYExtent }) {
  const meshRef = useRef();
  const { width: containerWidth, height: containerHeight } = useContext(DimensionContext); //this is unused

  // Calculate terrain size based on data extents
  // const xRange = xExtent[1] - xExtent[0];
  // const yRange = yExtent[1] - yExtent[0];
  const xRange = adjustedXExtent[1] - adjustedXExtent[0];
  const yRange = adjustedYExtent[1] - adjustedYExtent[0];
  const dataRange = Math.max(xRange, yRange);
  const scaleFactor = 4.0; // Adjust as needed
  const size = dataRange * scaleFactor;

  // const divisions = Math.max(Math.floor(size / 5), 50);
  const divisions = Math.max(Math.floor(size / 4), 200); // Increase the base resolution

  const [geometry, heightmap] = useMemo(() => {
    const geometry = new THREE.PlaneGeometry(size, size, divisions, divisions);
    geometry.rotateX(-Math.PI / 2);

    const width = divisions + 1;
    const depth = divisions + 1;

    const heightmap = generateHeightmapWithPeaksAndRidges(
      width,
      depth,
      itemPositions,
      parsedData,
      itemMetrics,
      adjustedXExtent,
      adjustedYExtent
    );

    const positionAttribute = geometry.attributes.position;
    const colorAttribute = new THREE.BufferAttribute(
      new Float32Array(positionAttribute.count * 3),
      3
    );

    for (let i = 0; i < positionAttribute.count; i++) {
      const elevation = heightmap[i];
      const vertex = new THREE.Vector3().fromBufferAttribute(
        positionAttribute,
        i
      );
      vertex.y = elevation;
      positionAttribute.setXYZ(i, vertex.x, vertex.y, vertex.z);

      const elevationNormalized = (elevation + 1) / 2;  //If artifacts are visible, tweak the range or smoothing logic.
      let color;
      if (elevationNormalized < 5) {
        color = new THREE.Color(colorPalette.cranberry);
      } else if (elevationNormalized < 10) {
        color = new THREE.Color(colorPalette.mdpurple);
      } else if (elevationNormalized < 15) {
        color = new THREE.Color(colorPalette.dkpurple);
      } else if (elevationNormalized < 20) {
        color = new THREE.Color(colorPalette.mdorange);
      } else if (elevationNormalized < 25) {
        color = new THREE.Color(colorPalette.dkorange);
      } else if (elevationNormalized < 30) {
        color = new THREE.Color(colorPalette.magenta);
      } else if (elevationNormalized < 35) {
        color = new THREE.Color(colorPalette.ltpurple);
      } else if (elevationNormalized < 40) {
        color = new THREE.Color(colorPalette.ltorange);
      } else {
        color = new THREE.Color(colorPalette.mdgreen);
      }

      colorAttribute.setXYZ(i, color.r, color.g, color.b);
    }

    geometry.setAttribute("color", colorAttribute);
    geometry.computeVertexNormals();
    console.log("[Terrain] Geometry and colors set up successfully.");

    return [geometry, heightmap];
  }, [
    itemPositions,
    parsedData,
    itemMetrics,
    divisions,
    size,
    adjustedXExtent,
    adjustedYExtent,
  ]);

  function getItemHeightForLabel(xPosOriginal, zPosOriginal) {
    const width = divisions + 1;

    // Map item coords to terrain coords using actual extents
    const xPos = mapToTerrainCoords(
      xPosOriginal,
      adjustedXExtent[0],
      adjustedXExtent[1],
      size
    );
    const zPos = mapToTerrainCoords(
      zPosOriginal,
      adjustedYExtent[0],
      adjustedYExtent[1],
      size
    );

    const [ix, iz] = normalizePosition(
      [xPosOriginal, zPosOriginal],
      width,
      adjustedXExtent,
      adjustedYExtent
    );

    const ixClamped = Math.max(0, Math.min(width - 1, ix));
    const izClamped = Math.max(0, Math.min(width - 1, iz));

    const i = izClamped * width + ixClamped;
    const h = heightmap[i] || 0;

    return { terrainX: xPos, terrainZ: zPos, height: h };
  }

  const labelPositions = useMemo(() => {
    return itemPositions.map(({ item, position }) => {
      const [origX, origZ] = position;
      const { terrainX, terrainZ, height: peakHeight } = getItemHeightForLabel(
        origX,
        origZ
      );
      const finalPos = [terrainX, peakHeight + 3, terrainZ];
      return { item, position: finalPos };
    });
  }, [itemPositions, heightmap, size, adjustedXExtent, adjustedYExtent]);

  return (
    <>
      <mesh ref={meshRef} geometry={geometry}>
        <meshStandardMaterial
          vertexColors={true}
          side={THREE.DoubleSide}
          flatShading={false}
        />
      </mesh>
      {labelPositions.map(({ item, position }) => (
        <Html
          key={item}
          position={position}
          style={{ pointerEvents: "none" }}
        >
          <div
            style={{
              color: "white",
              fontSize: "9px",
              fontFamily: "Manjari, sans-serif;",
              textAlign: "center",
              background: "rgba(0,0,0,0.7)",
              padding: "2px 3px",
              borderRadius: "2px",
              boxShadow: "0px 1px 2px rgba(0,0,0,0.3)",
            }}
          >
            {item}
          </div>
        </Html>
      ))}
    </>
  );
}

const Peaks = () => {
  const { data } = useDataset("fpgrowth");

  const fpgrowthData = useMemo(() => {
    if (!data?.FPGrowth) return null;
    const cleanedString = data.FPGrowth.replace(/Infinity/g, "2");
    try {
      return JSON.parse(cleanedString);
    } catch {
      console.error("Error parsing FPGrowth data.", cleanedString);
      return null;
    }
  }, [data]);

  const parsedData = useMemo(() => {
    if (!fpgrowthData || !fpgrowthData.data) return [];
    const { columns, data: rows } = fpgrowthData;
    return rows.map((row) => {
      const rowData = {};
      columns.forEach((col, index) => {
        rowData[col] = row[index];
      });
      const antecedents = parseItemList(rowData["antecedents"]);
      const consequents = parseItemList(rowData["consequents"]);
      return {
        antecedents,
        consequents,
        antecedent_support: rowData["antecedent support"] || 0,
        consequent_support: rowData["consequent support"] || 0,
        support: rowData["support"] || 0,
        confidence: rowData["confidence"] || 0,
        lift: rowData["lift"] || 0,
        conviction: rowData["conviction"] || 0,
        zhangs_metric: rowData["zhangs_metric"] || 0,
      };
    });
  }, [fpgrowthData]);

  const uniqueItems = useMemo(() => {
    const items = new Set();
    parsedData.forEach((rule) => {
      rule.antecedents.forEach((item) => items.add(item));
      rule.consequents.forEach((item) => items.add(item));
    });
    return Array.from(items);
  }, [parsedData]);

  const itemMetrics = useMemo(() => {
    const metrics = {};
    uniqueItems.forEach((item) => {
      metrics[item] = {
        support: 0,
        conviction: 0,
        zhangs_metric: 0,
        lift: 0,
        count: 0,
      };
    });
    distributeMetrics(parsedData, metrics);
    console.log("[Peaks] Item metrics (aggregated):", metrics);
    return metrics;
  }, [parsedData, uniqueItems]);

  const {
    itemPositions: arrangedItemPositions,
    adjustedXExtent,
    adjustedYExtent,
  } = useMemo(() => {
    const { itemPositions, adjustedXExtent, adjustedYExtent } = arrangeItems(
      uniqueItems,
      parsedData
    );
    // Map positions to include metrics
    const positions = itemPositions.map(({ item, position }) => ({
      item,
      position,
      metrics:
        itemMetrics[item] || {
          support: 0,
          lift: 0,
          conviction: 0,
          zhangs_metric: 0,
        },
    }));
    return {
      itemPositions: positions,
      adjustedXExtent,
      adjustedYExtent,
    };
  }, [uniqueItems, itemMetrics, parsedData]);


  console.log("[FPGraphPage] Render start");
  return (
    <div style={{ height: "100vh", width: "100vw" }}>
      {/* <Canvas camera={{ position: [0, 300, 300], fov: 75 }}> */}
      {/* <Canvas camera={{ position: [-150, 300, 250], fov: 125 }}> */}
      <Canvas camera={{ position: [-250, 300, 250], fov: 45 }} dpr={[1, 2]} antialias={true} >
      {/* <OrbitControls enableZoom={true} /> */}
      <OrbitControls enableZoom={true} minDistance={50} maxDistance={1000} />

        <ambientLight intensity={1.6} />
        <directionalLight position={[10, 30, 10]} intensity={2.8} />
        {parsedData && <Terrain itemPositions={arrangedItemPositions} parsedData={parsedData} itemMetrics={itemMetrics}            
          adjustedXExtent={adjustedXExtent}
          adjustedYExtent={adjustedYExtent} />}
      </Canvas>
    </div>
  );
};

export default Peaks;