import React, { useEffect, useMemo, useRef, useState } from "react";
import { Theme, useTheme } from "@mui/material";
import { ComputedNode } from "@nivo/network";
import { LocationChartNode } from "./locationNetworkChartUtils";
import { css } from "@emotion/react";
import cssComponentsStyles from "../../../Global/Styles/components";

/**
 * Builds the Emotion styles used by the node label layer.
 *
 * - `textStyle`: label text, using the theme's caption typography and
 *   shifted 20px upward so it sits above the node.
 * - `backgroundStyle`: translucent fill drawn behind the text; light grey in
 *   light mode, the theme's custom dark grey otherwise.
 */
const cssStyles = (theme: Theme) => {
  const { caption } = theme.typography;
  const isLightMode = theme.palette.mode === "light";
  const backgroundFill = isLightMode
    ? theme.palette.grey["100"]
    : theme.palette.customColors.darkGray;

  return {
    textStyle: css({
      fill: theme.palette.text.primary,
      fontSize: caption.fontSize,
      fontWeight: caption.fontWeight,
      transform: "translate(0, -20px)",
    }),
    backgroundStyle: css({
      fill: backgroundFill,
      opacity: 0.7,
    }),
  };
};

/** Props for the `NodeLabelsLayer` component. */
interface NodeLabelsLayerProps {
  /** Nodes as laid out by `@nivo/network`; their `id`, `x` and `y` are read. */
  nodes: ComputedNode<LocationChartNode>[];
  /**
   * Label prefixes to show. A node gets a label when the last `/`-separated
   * segment of its id starts with one of these strings.
   */
  visibleNodesLabels: string[];
}

/** A single label to draw, derived from one node in `NodeLabelsLayer`. */
interface NodeLabel {
  /** Id of the node this label belongs to; also used to key its rendered elements and measured width. */
  id: string;
  /** Horizontal position of the (middle-anchored) label text: the node's x. */
  x: number;
  /** Vertical position passed to the label text: the node's y + 4. */
  y: number;
  /**
   * Text to display: the first `/`-separated segment of the node id that
   * starts with the matched visible-label prefix, or `undefined` if none does.
   */
  layerName: string | undefined;
}

/**
 * Custom network-chart layer that draws a text label, on a rounded
 * translucent background, above every node whose id matches one of
 * `visibleNodesLabels` (see `shouldDisplayNodeId`).
 *
 * The background `<rect>` is sized from the rendered width of each `<text>`
 * element, measured after the label set changes.
 */
const NodeLabelsLayer: React.FC<NodeLabelsLayerProps> = ({
  nodes,
  visibleNodesLabels,
}) => {
  const theme = useTheme();
  // The styles depend only on the theme, so avoid rebuilding them every render.
  const styles = useMemo(
    () => ({
      ...cssStyles(theme),
      ...cssComponentsStyles(theme),
    }),
    [theme]
  );
  const [labelWidths, setLabelWidths] = useState<{ [key: string]: number }>({});
  const labelRefs = useRef<{ [key: string]: SVGTextElement | null }>({});

  const labels = useMemo(
    () =>
      nodes
        .map((node): NodeLabel | null => {
          const visibleLayer = shouldDisplayNodeId(node.id, visibleNodesLabels);
          if (visibleLayer) {
            const layerName = node.id
              .split("/")
              .find((layer) => layer.startsWith(visibleLayer));
            return { id: node.id, x: node.x, y: node.y + 4, layerName };
          }
          return null;
        })
        .filter((node): node is NodeLabel => node !== null),
    [nodes, visibleNodesLabels]
  );

  // Measure the rendered text of every current label so its background can
  // be sized to fit. Keyed on `labels` (not just `nodes`) so that changes to
  // `visibleNodesLabels`, which add/remove labels or change their text, are
  // re-measured too. Only current label ids are kept, so stale widths drop out.
  useEffect(() => {
    const newLabelWidths: { [key: string]: number } = {};
    labels.forEach(({ id }) => {
      const textElement = labelRefs.current[id];
      if (textElement) {
        newLabelWidths[id] = textElement.getComputedTextLength();
      }
    });
    setLabelWidths(newLabelWidths);
  }, [labels]);

  return (
    <>
      {labels.map(({ id, x, y, layerName }) => {
        const labelWidth = labelWidths[id];
        // The key must sit on the outermost element returned from `map`.
        return (
          <React.Fragment key={id}>
            {/* Until the text has been measured the width is unknown; drawing
                the rect then would produce NaN x/width attributes. */}
            {labelWidth !== undefined && (
              <rect
                x={x - (labelWidth / 2 + 5)}
                y={y - 34}
                width={labelWidth + 10}
                height={20}
                rx="8"
                ry="8"
                css={styles.backgroundStyle}
              />
            )}
            <text
              ref={(el) => {
                labelRefs.current[id] = el;
              }}
              x={x}
              y={y}
              textAnchor="middle"
              css={styles.textStyle}
            >
              {layerName}
            </text>
          </React.Fragment>
        );
      })}
    </>
  );
};

export default NodeLabelsLayer;

/**
 * Decides whether a node should be labelled.
 *
 * Looks at the last `/`-separated segment of `nodeId` and returns the first
 * entry of `visibleNodesLabels` that this segment starts with. Returns `false`
 * when no entry matches, or when the first match is the empty string.
 */
const shouldDisplayNodeId = (
  nodeId: string,
  visibleNodesLabels: string[]
): string | false => {
  // Everything after the final "/" (the whole id when there is no "/").
  const tail = nodeId.slice(nodeId.lastIndexOf("/") + 1);

  for (const prefix of visibleNodesLabels) {
    if (tail.startsWith(prefix)) {
      return prefix || false;
    }
  }
  return false;
};
