import { Stack, useTheme } from '@mui/material';
import { AxisBottom, AxisLeft, AxisRight, AxisTop } from '@visx/axis';
import { Grid } from '@visx/grid';
import { Group } from '@visx/group';
import { scaleBand, scaleLinear, scaleOrdinal, StringLike } from '@visx/scale';
import { AreaStack } from '@visx/shape';
import { SeriesPoint, StackKey } from '@visx/shape/lib/types';
import { defaultStyles, Tooltip, useTooltip } from '@visx/tooltip';
import { curveBasis, Series } from 'd3-shape';
import { useCallback, useMemo } from 'react';

import { Colors } from './constants';
import { Legend } from './Legend';
import { StackedAreaChartProps, TooltipData, ValueFormatter } from './types';
import { getChartContainerDirection, getChartDimensions } from './utils';

/** Fallback outer spacing (in px) applied around the plot area when no `margin` prop is given. */
const defaultMargin = {
  top: 40,
  left: 50,
  right: 40,
  bottom: 100,
};

/** Fallback tooltip formatter: renders the series key followed by its raw value, e.g. `sales: 12`. */
const defaultValueFormat: ValueFormatter = function (value, key) {
  return `${key}: ${value}`;
};

/**
 * Stacked area chart rendered with visx.
 *
 * Every key in `yKeys` becomes one stacked area; `xKey` selects the category
 * placed on the band x-scale. The plot area is divided into equal-width
 * segments (one per datum) by white divider lines, and hover/click events on
 * an area resolve the segment under the pointer to a single datum of that
 * series (shown in a tooltip and/or passed to `onClickHandler`).
 *
 * Returns `null` when `width` is below 10px.
 */
export function StackedAreaChart<T extends Record<string, string | number>>({
  axisBottom = false,
  axisColor = Colors.BLUE,
  axisLeft = true,
  axisRight = true,
  axisTop = false,
  areaColors = [Colors.PURPLE, Colors.RED, Colors.ORANGE, Colors.YELLOW],
  data,
  gridColor = Colors.BLACK,
  gridOpacity = 0.07,
  height,
  legendAlign = 'center',
  legendContainerSize = 0.2,
  legendDirection = 'column',
  legendFontSize = 12,
  legendIconShape = 'circle',
  legendJustify = 'center',
  legendPosition = 'left',
  margin = defaultMargin,
  onClickHandler,
  showGrid = false,
  showLegends = false,
  tooltipBackgroundColor = Colors.BLACK,
  tooltips = true,
  tooltipTextColor = Colors.WHITE,
  valueFormat = defaultValueFormat,
  width,
  xKey,
  yKeys,
}: StackedAreaChartProps<T>) {
  const {
    tooltipData,
    tooltipLeft,
    tooltipTop,
    tooltipOpen,
    showTooltip,
    hideTooltip,
  } = useTooltip<TooltipData>();
  const [
    svgWidth,
    svgHeight,
    chartWidth,
    chartHeight,
    legendContainerWidth,
    legendContainerHeight,
  ] = getChartDimensions(
    width,
    height,
    margin,
    showLegends,
    legendPosition,
    legendContainerSize
  );

  const theme = useTheme();

  const tooltipStyles = {
    ...defaultStyles,
    minWidth: 60,
    backgroundColor: tooltipBackgroundColor,
    color: tooltipTextColor,
    zIndex: theme.zIndex.tooltip,
  };

  const getX = useCallback((d: T): StringLike => d[xKey], [xKey]);
  const getY0 = (d: SeriesPoint<T>) => d[0];
  const getY1 = (d: SeriesPoint<T>) => d[1];

  // Largest per-row sum over all yKeys, floored at 1 so an empty or all-zero
  // dataset still yields a non-degenerate y domain. Memoized so the yScale
  // memo below is not invalidated on every render (a freshly built array
  // dependency would change identity each time). A running max is used
  // instead of `Math.max(...array)` so large datasets cannot hit the engine's
  // argument-count limit.
  const maxValueTotal = useMemo(
    () =>
      data.reduce<number>((max, item) => {
        const rowTotal = yKeys.reduce<number>(
          (total, k) => total + Number(item[k] ?? 0),
          0
        );
        return Math.max(max, rowTotal);
      }, 1),
    [data, yKeys]
  );

  // scales, memoize for performance
  const yScale = useMemo(
    () =>
      scaleLinear<number>({
        range: [chartHeight, 0],
        domain: [0, maxValueTotal],
      }),
    [maxValueTotal, chartHeight]
  );
  const xScale = useMemo(
    () =>
      scaleBand<StringLike>({
        range: [0, chartWidth],
        round: true,
        domain: data.map(getX),
        paddingInner: 1,
        paddingOuter: 0,
      }),
    [chartWidth, data, getX]
  );

  const colorScale = scaleOrdinal<string | number, string>({
    domain: yKeys as string[],
    range: areaColors,
  });

  // Width of one hover/click segment: the plot width split evenly across the
  // x-domain entries (0 when there is no data).
  const bandWidth = xScale.domain().length
    ? chartWidth / xScale.domain().length
    : 0;

  /**
   * Maps a pointer event on a stack's path to the datum of the segment under
   * the pointer. Coordinates are relative to the path's bounding box.
   */
  const getClickPoint = useCallback(
    (
      stack: Series<T, StackKey>,
      evt: React.MouseEvent<SVGPathElement, MouseEvent>
    ) => {
      const rect = evt.currentTarget.getBoundingClientRect();
      const mouseX = evt.clientX - rect.left;
      const mouseY = evt.clientY - rect.top;

      // Clamp to the last pixel so the right edge still maps to the final segment.
      const bandIndex = bandWidth
        ? Math.floor(Math.min(mouseX, chartWidth - 1) / bandWidth)
        : 0;

      const point = stack[bandIndex];
      return { mouseX, mouseY, point };
    },
    [bandWidth, chartWidth]
  );

  // HACK: Use mouse position to figure out which segment to show tooltip for
  const openTooltip = (
    stack: Series<T, StackKey>,
    evt: React.MouseEvent<SVGPathElement, MouseEvent>
  ) => {
    const { mouseX, mouseY, point } = getClickPoint(stack, evt);
    if (!point) return;

    showTooltip({
      tooltipData: {
        key: `${getX(point.data)}`,
        value: valueFormat(point.data[stack.key], stack.key),
      },
      // Keep the tooltip away from the bottom/right edges, but never let the
      // clamp go negative on small charts (that would push it outside the
      // container).
      tooltipTop: Math.max(
        0,
        Math.min(mouseY, chartHeight - margin.bottom - 40)
      ),
      tooltipLeft: Math.max(
        0,
        Math.min(mouseX, chartWidth - margin.right - 100)
      ),
    });
  };

  /** Forwards a click on an area to `onClickHandler` with the resolved datum and series key. */
  const handleClick = (
    stack: Series<T, StackKey>,
    evt: React.MouseEvent<SVGPathElement, MouseEvent>
  ) => {
    const { point } = getClickPoint(stack, evt);

    if (!point || !onClickHandler) return;

    onClickHandler(point.data, stack.key);
  };

  return width < 10 ? null : (
    <Stack
      direction={getChartContainerDirection(legendPosition)}
      justifyContent="space-between"
      height="100%"
      maxWidth="fit-content"
      position="relative"
    >
      <svg width={svgWidth} height={svgHeight}>
        {showGrid && (
          <Grid
            top={margin.top}
            left={margin.left}
            xScale={xScale}
            yScale={yScale}
            width={chartWidth}
            height={chartHeight}
            stroke={gridColor}
            strokeOpacity={gridOpacity}
          />
        )}
        <Group top={margin.top} left={margin.left}>
          <defs>
            {/**
             * Clip path for the segment overlay lines: the white divider lines
             * below are clipped to the stacked-area shapes so they only show
             * inside the filled region.
             * NOTE(review): this id is document-global; two charts on the same
             * page would both resolve `url(#stacked-area-chart)` to the first
             * one. Consider a per-instance id (e.g. React's `useId`, if the
             * project's React version provides it).
             */}
            <clipPath id="stacked-area-chart">
              <AreaStack<T>
                top={margin.top}
                left={margin.left}
                keys={yKeys}
                data={data}
                curve={curveBasis}
                x={(d) => xScale(getX(d.data)) ?? 0}
                y0={(d) => yScale(getY0(d)) ?? 0}
                y1={(d) => yScale(getY1(d)) ?? 0}
              >
                {({ stacks, path }) =>
                  stacks.map((stack) => (
                    <path
                      key={`stack-${stack.key}`}
                      d={path(stack) || ''}
                      stroke={colorScale(stack.key) ?? ''}
                      fill={colorScale(stack.key) ?? ''}
                      fillOpacity={0.8}
                    />
                  ))
                }
              </AreaStack>
            </clipPath>
          </defs>
          {/*
           * The actual chart that is displayed.
           * NOTE(review): top/left here are in addition to the enclosing
           * Group's offset; verify whether visx applies them when a children
           * render function is supplied (if so, the areas are offset twice).
           */}
          <AreaStack<T>
            top={margin.top}
            left={margin.left}
            keys={yKeys}
            data={data}
            curve={curveBasis}
            x={(d) => xScale(getX(d.data)) ?? 0}
            y0={(d) => yScale(getY0(d)) ?? 0}
            y1={(d) => yScale(getY1(d)) ?? 0}
          >
            {({ stacks, path }) =>
              stacks.map((stack) => (
                <path
                  key={`stack-${stack.key}`}
                  d={path(stack) || ''}
                  stroke={colorScale(stack.key) ?? ''}
                  fill={colorScale(stack.key) ?? ''}
                  onClick={
                    onClickHandler ? (e) => handleClick(stack, e) : undefined
                  }
                  onMouseLeave={
                    tooltips
                      ? () => {
                          hideTooltip();
                        }
                      : undefined
                  }
                  onMouseMove={
                    tooltips ? (e) => openTooltip(stack, e) : undefined
                  }
                  style={{ cursor: 'pointer' }}
                />
              ))
            }
          </AreaStack>
          {/*
           * Overlay lines separating the segments. Lines are positioned purely
           * by index, so the index is used as the key (x values may repeat).
           */}
          {data.slice(1).map((_d, index) => (
            <line
              key={`overlay-line-${index}`}
              x1={bandWidth * (index + 1)}
              x2={bandWidth * (index + 1)}
              y1={chartHeight}
              y2={0}
              stroke={Colors.WHITE}
              strokeWidth={2}
              clipPath='url("#stacked-area-chart")'
            />
          ))}
        </Group>
        {axisTop && (
          <AxisTop
            left={margin.left}
            top={margin.top}
            scale={xScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'middle',
              dy: '-0.33em',
            })}
          />
        )}
        {axisBottom && (
          <AxisBottom
            left={margin.left}
            top={chartHeight + margin.top}
            scale={xScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'middle',
            })}
          />
        )}
        {axisLeft && (
          <AxisLeft
            left={margin.left}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'end',
              dy: '0.33em',
            })}
          />
        )}
        {axisRight && (
          <AxisRight
            left={chartWidth + margin.left}
            top={margin.top}
            scale={yScale}
            stroke={axisColor}
            tickStroke={axisColor}
            tickLabelProps={() => ({
              fill: axisColor,
              fontSize: 11,
              textAnchor: 'start',
              dy: '0.33em',
            })}
          />
        )}
      </svg>
      {showLegends && (
        <Legend
          legendPosition={legendPosition}
          legendDirection={legendDirection}
          legendJustify={legendJustify}
          legendAlign={legendAlign}
          legendFontSize={legendFontSize}
          legendIconShape={legendIconShape}
          maxWidth={legendContainerWidth}
          maxHeight={legendContainerHeight}
          margin={margin}
          colorScale={colorScale}
        />
      )}
      {tooltips && tooltipOpen && tooltipData && (
        <Tooltip top={tooltipTop} left={tooltipLeft} style={tooltipStyles}>
          <div>
            <strong>{tooltipData.key}</strong>
          </div>
          <div>{tooltipData.value}</div>
        </Tooltip>
      )}
    </Stack>
  );
}
