import { useEffect, useMemo, useState } from "react";
import { useTooltip } from "@visx/tooltip";
import { scaleOrdinal } from "@visx/scale";
import { localPoint } from "@visx/event";
import { BarStack } from "@visx/shape";
import { Group } from "@visx/group";
import { LegendOrdinal } from "@visx/legend";
import { AxisRight } from "@visx/axis";
import { animated, useTrail } from "@react-spring/web";
import { ScaleBand } from "d3-scale";

// Components
import GridLines from "../GridLines";
import CustomDateTimeBottomAxis from "../CustomDateTimeBottomAxis";
import Tooltip from "../../widgets/Tooltip";

// Data and Definitions
import { darkGray, portalColors } from "../../../style/colors";
import { PRIMARY_FONT } from "../../../style/fonts";
import { MULTI_SERIES_CHART_PROPS } from "../../../types/charts/props";
import {
	OBSERVATION_WITH_DATE_LABEL,
	OBSERVATION_GENERIC,
	areObservationsDateTimeLabeled,
} from "../../../types/charts";

// Hooks
import useSVGDimensions from "../../../hooks/useSVGDimensions";
import useBandScale from "../../../hooks/useBandScale";
import useLinearScale from "../../../hooks/useLinearScale";
import useFlattenSeriesValues from "../../../data/hooks/useFlattenSeriesValues";
import useObserveChartComponents from "../../../data/hooks/useObserveChartComponents";
import ChartContainer from "../../ChartContainer";
import { DateTime } from "luxon";
import { isDateTime } from "../../../utils/datesAndTimes";

/** Default chart margins, hoisted so the object identity is stable across renders. */
const DEFAULT_MARGINS = { vertical: 20, horizontal: 20 };

/**
 * Computes the highest point any stacked column reaches.
 *
 * `BarStack` stacks each observation's series on top of one another in the
 * order given by `seriesNames`. The y-domain therefore has to cover the running
 * (cumulative) total of a column, not merely the largest single value.
 * Values that are not finite numbers after `Number(...)` coercion are skipped.
 *
 * @param observations - observations whose `data` holds one value per series name
 * @param seriesNames - series keys, in stacking order
 * @returns the tallest cumulative stack height, never below 0 (0 for empty input)
 */
function getTallestStackHeight(
	observations: ReadonlyArray<{ data?: unknown }>,
	seriesNames: ReadonlyArray<unknown>,
): number {
	let tallest = 0;
	for (const observation of observations) {
		if (typeof observation.data !== "object" || observation.data === null) {
			continue;
		}
		const values = observation.data as Record<string, unknown>;
		let runningTotal = 0;
		for (const seriesName of seriesNames) {
			const value = Number(values[String(seriesName)]);
			if (!Number.isFinite(value)) continue;
			runningTotal += value;
			// Track the peak rather than the final total, so a negative layer
			// stacked after a positive one cannot hide the column's true top.
			tallest = Math.max(tallest, runningTotal);
		}
	}
	return tallest;
}

/**
 * Bar chart that can be used for single or multi-series data.
 *
 * If used for multiple series, the bars will stack vertically.
 *
 * Behavior visible in this component:
 * - While `dataRefreshing` is truthy, nothing is rendered inside the container
 *   and the bars snap back to their collapsed, desaturated state. When it
 *   becomes falsy, the bars animate (trail) into view.
 * - Hovering a bar shows a tooltip with that observation's data and dims every
 *   other observation's column.
 * - Values are labelled on a right-hand axis. A date/time bottom axis is drawn
 *   when `axes.observationLabel === "bottom"` and the observation labels are
 *   luxon `DateTime`s.
 * - A column legend is rendered when `displayLegend` is true.
 */
export default function StackedBarChart<
	ObservationType extends OBSERVATION_GENERIC,
>({
	margins = DEFAULT_MARGINS,
	observations,
	seriesNames,
	axes,
	dataColors,
	chartTitle,
	chartExplanation,
	colorTheme = portalColors.gray,
	displayLegend = false,
	transparentBackground = false,
	gridLines = "rows",
	labelFormatter = (label) => label.toString(),
	dataFormatter = (data) => data.toString(),
	tooltip,
	dataRefreshing,
}: MULTI_SERIES_CHART_PROPS<ObservationType>) {
	// Index of the hovered observation (column), or null when nothing is hovered.
	const [currentObservationIndex, setCurrentObservationIndex] = useState<
		number | null
	>(null);

	const observationsDateLabeled = areObservationsDateTimeLabeled(observations);

	const {
		observeComponent,
		componentWidth,
		componentHeight,
		observeTitle,
		observeLegend,
		titleHeight,
		legendHeight,
	} = useObserveChartComponents();

	// Determine the effective dimensions for the actual SVG element.
	const svgDimensions = useSVGDimensions({
		width: componentWidth,
		height: componentHeight,
		margins,
		heightSubstractors: [titleHeight, legendHeight],
	});

	// Scales
	const xScale = useBandScale({
		domain: observations.map((observation) => observation.observationLabel),
		range: [0, svgDimensions[0] - 20],
	});

	const allSeriesValues = useFlattenSeriesValues(
		observations,
		seriesNames,
	) as number[];

	// Bars are stacked, so the y-domain must reach the top of the tallest
	// column, not just the largest single series value. The largest single
	// value is kept as a lower bound, so the domain is never smaller than a
	// per-value maximum would give. reduce() avoids spreading an arbitrarily
	// large array into Math.max, and the 0 seed keeps empty input finite.
	const yDomainMax = useMemo(() => {
		const largestSingleValue = allSeriesValues.reduce(
			(largest, value) =>
				Number.isFinite(value) ? Math.max(largest, value) : largest,
			0,
		);
		return Math.max(
			largestSingleValue,
			getTallestStackHeight(observations, seriesNames),
		);
	}, [allSeriesValues, observations, seriesNames]);

	const yScale = useLinearScale({
		domain: [0, yDomainMax],
		// Leave 40px at the bottom for the observation-label axis when one is configured.
		range: [svgDimensions[1] - (axes?.observationLabel ? 40 : 0), 0],
	});

	const colorScale = useMemo(
		() =>
			scaleOrdinal({
				domain: seriesNames,
				range: dataColors,
			}),
		[dataColors, seriesNames],
	);

	// One spring per observation (column); all series in a column share it.
	const [trails, trailsAPI] = useTrail(observations.length, () => ({
		from: { scaleY: 0, opacity: 0, filter: "saturate(0%)" },
		config: { tension: 500, friction: 30 },
	}));

	// The `BarStack` component requires data in a specific format.
	const visxDataTransformation = useMemo<
		Array<
			{
				observationLabel: ObservationType["observationLabel"];
			} & ObservationType["data"]
		>
	>(
		() =>
			observations.map((observation) => ({
				observationLabel: observation.observationLabel,
				...observation.data,
			})),
		[observations],
	);

	// Collapse the bars instantly while refreshing, and animate them in otherwise.
	// `observations.length` is a dependency because springs created for newly
	// added observations start at their `from` values. Without re-running the
	// effect, those new bars would stay invisible.
	useEffect(() => {
		if (dataRefreshing) {
			trailsAPI.start({
				to: { scaleY: 0, opacity: 0, filter: "saturate(0%)" },
				immediate: true,
			});
		} else {
			trailsAPI.start({
				to: { scaleY: 1, opacity: 1, filter: "saturate(100%)" },
				immediate: false,
			});
		}
	}, [dataRefreshing, trailsAPI, observations.length]);

	// Tooltip Artifacts
	const {
		tooltipData,
		tooltipLeft,
		tooltipTop,
		tooltipOpen,
		showTooltip,
		hideTooltip,
	} = useTooltip<ObservationType>();

	return (
		<ChartContainer
			colorTheme={colorTheme}
			margins={margins}
			transparentBackground={transparentBackground}
			observer={observeComponent}
			chartTitle={chartTitle}
			titleObserver={observeTitle}
			chartExplanation={chartExplanation}
		>
			{dataRefreshing ? null : (
				<div>
					{tooltipOpen && tooltipData && (
						<Tooltip
							positioning={{ top: tooltipTop, left: tooltipLeft }}
							title={tooltip?.title ?? "Record Details"}
							// toISODate() returns null for an invalid DateTime; fall back to
							// luxon's string form instead of asserting the null away.
							subtitle={
								isDateTime(tooltipData.observationLabel)
									? tooltipData.observationLabel.toISODate() ??
										tooltipData.observationLabel.toString()
									: tooltipData.observationLabel
							}
							additionalData={tooltipData.data}
							additionalDataFormatter={dataFormatter}
							colorScale={colorScale}
						/>
					)}
					{/* Explicit comparisons: `0 && <svg/>` would render a stray "0". */}
					{svgDimensions[0] > 0 && svgDimensions[1] > 0 && (
						<svg
							viewBox={`0 0 ${svgDimensions[0]} ${svgDimensions[1]}`}
							css={{
								width: svgDimensions[0],
								height: svgDimensions[1],
								overflow: "unset",
							}}
						>
							<title>Stacked Bar Chart</title>
							<GridLines
								lineOption={gridLines}
								lineColor={colorTheme.mid}
								xScale={xScale}
								yScale={yScale}
								width={svgDimensions[0] - margins.horizontal}
								height={svgDimensions[1]}
							/>
							<Group>
								<BarStack
									data={visxDataTransformation}
									keys={seriesNames}
									x={(d) => d.observationLabel}
									xScale={xScale}
									yScale={yScale}
									color={colorScale}
								>
									{(barStacks) =>
										barStacks.map((barStack) =>
											// `index` is the position within this series' bars, i.e. the
											// observation (column) index. That lines up with `trails`
											// and `observations`.
											barStack.bars.map((bar, index) => {
												if (Number.isNaN(bar.height)) return null;

												const dimmed =
													currentObservationIndex !== null &&
													currentObservationIndex !== index;

												return (
													<animated.g
														style={trails[index]}
														css={{
															transformOrigin: "bottom",
														}}
														key={`bar-stack-${barStack.index}-${bar.index}`}
													>
														<animated.rect
															x={bar.x}
															y={bar.y}
															height={bar.height}
															width={bar.width}
															fill={bar.color}
															onMouseLeave={() => {
																hideTooltip();
																setCurrentObservationIndex(null);
															}}
															onMouseMove={() => {
																// Anchor the tooltip just right of the column,
																// pinned vertically to the bottom of the chart.
																const left = bar.x + xScale.bandwidth() + 15;

																showTooltip({
																	tooltipData: observations[index],
																	tooltipTop: svgDimensions[1],
																	tooltipLeft: left,
																});

																setCurrentObservationIndex(index);
															}}
														/>
														{/* Darkening overlay: visible only on columns other than the hovered one. */}
														<animated.rect
															css={{
																transition: "all .5s ease",
																opacity: dimmed ? 0.25 : 0,
																pointerEvents: "none",
															}}
															x={bar.x}
															y={bar.y}
															height={bar.height}
															width={bar.width}
															fill={"black"}
														/>
													</animated.g>
												);
											}),
										)
									}
								</BarStack>
							</Group>
							<Group name="axes">
								<AxisRight
									left={svgDimensions[0] - 7}
									scale={yScale}
									hideZero
									tickLength={0}
									stroke={"transparent"}
									tickFormat={dataFormatter}
									tickStroke={darkGray}
									tickLabelProps={() => ({
										fill: colorTheme.dark,
										fontSize: 11,
										verticalAnchor: "middle",
										x: -5,
									})}
								/>
								{axes?.observationLabel === "bottom" &&
								observationsDateLabeled ? (
									<CustomDateTimeBottomAxis
										yPosition={svgDimensions[1] - 40}
										color={colorTheme.dark}
										dateScale={xScale as ScaleBand<DateTime>}
									/>
								) : null}
							</Group>
						</svg>
					)}

					{displayLegend && (
						<div ref={observeLegend} css={{ paddingTop: 5, paddingBottom: 5 }}>
							<LegendOrdinal
								scale={colorScale}
								direction="column"
								labelMargin="0 0px 0 2px"
								labelFormat={labelFormatter}
								shapeHeight={10}
								shapeWidth={10}
								css={{
									fontSize: 11,
									fontFamily: PRIMARY_FONT,
									textTransform: "capitalize",
									fontWeight: 500,
									color: colorTheme.dark,
								}}
							/>
						</div>
					)}
				</div>
			)}
		</ChartContainer>
	);
}
