import * as tf from "@tensorflow/tfjs";
import { store } from "../../../../store/store";
import { getLastCloses } from "../../streaming.js";
import { toast } from "react-toastify";
import { create, all } from "mathjs";

// mathjs instance created with the complete function set (`all`) imported from mathjs.
const math = create(all);

/**
 * Shows a short dark toast in the bottom-right corner, styled with the app's orange progress bar.
 * Returns nothing; the toast id produced by react-toastify is intentionally discarded.
 *
 * @param {string} message - text to display.
 */
const showToast = (message) => {
	const options = {
		theme: "dark",
		position: "bottom-right",
		autoClose: 1000,
		hideProgressBar: false,
		closeOnClick: true,
		pauseOnHover: true,
		draggable: true,
		progressStyle: { background: "#f74712" },
	};
	toast.dark(message, options);
};

/** Number of historical closes requested for training; also the cap on the live price buffer. */
const PRICE_HISTORY_LIMIT = 5000;

/**
 * Indicator periods shared by dataset construction AND live prediction.
 * Training and inference must use identical settings, otherwise the model is fed
 * features whose meaning differs from what it was trained on.
 */
const INDICATOR_CONFIG = Object.freeze({
	rsiPeriod: 7,
	fastPeriod: 3,
	slowPeriod: 7,
	signalPeriod: 5,
	momentumPeriod: 5,
});

/** Indicator values appended after the normalized price window: RSI, MACD line, signal line, momentum. */
const INDICATOR_FEATURE_COUNT = 4;

/** Arithmetic mean of `values`; NaN for an empty array. */
function arithmeticMean(values) {
	return values.reduce((sum, value) => sum + value, 0) / values.length;
}

/**
 * Relative Strength Index using simple averages of the last `period` price differences.
 *
 * @param {number[]} prices - chronologically ordered prices.
 * @param {number} [period=14] - look-back length.
 * @returns {number|null} RSI in [0, 100], or null when data is insufficient or invalid.
 */
function calculateRSI(prices, period = 14) {
	if (prices.length < period || prices.length < 2) return null;

	const gains = [];
	const losses = [];
	for (let i = 1; i < prices.length; i++) {
		const difference = prices[i] - prices[i - 1];
		// Math.max keeps NaN differences as NaN so invalid data yields null below.
		gains.push(Math.max(difference, 0));
		losses.push(Math.max(-difference, 0));
	}

	const averageGain = arithmeticMean(gains.slice(-period));
	const averageLoss = arithmeticMean(losses.slice(-period));
	if (averageLoss === 0) return 100;

	const rsi = 100 - 100 / (1 + averageGain / averageLoss);
	return Number.isNaN(rsi) ? null : rsi;
}

/**
 * Exponential moving average for every position of `values`, seeded with the first value.
 *
 * @param {number[]} values
 * @param {number} period - smoothing period (multiplier = 2 / (period + 1)).
 * @returns {number[]} series of the same length as `values`.
 */
function emaSeries(values, period) {
	const multiplier = 2 / (period + 1);
	const series = [];
	for (const value of values) {
		const previous = series.length === 0 ? value : series[series.length - 1];
		series.push((value - previous) * multiplier + previous);
	}
	return series;
}

/**
 * MACD line and signal line computed entirely from `prices`.
 * Stateless: the signal line is the EMA of this window's own MACD series, so the
 * result depends only on the input (unlike a shared, ever-growing history).
 *
 * @returns {{macdLine: number, signalLine: number}} NaN values when the window is too short.
 */
function calculateMACD(prices, fastPeriod, slowPeriod, signalPeriod) {
	if (prices.length === 0 || prices.length < Math.max(fastPeriod, slowPeriod, signalPeriod)) {
		return { macdLine: NaN, signalLine: NaN };
	}

	const fast = emaSeries(prices, fastPeriod);
	const slow = emaSeries(prices, slowPeriod);
	const macdSeries = fast.map((value, i) => value - slow[i]);
	const signalSeries = emaSeries(macdSeries, signalPeriod);
	const last = macdSeries.length - 1;

	return { macdLine: macdSeries[last], signalLine: signalSeries[last] };
}

/**
 * Price change between the last price and the price `period - 1` samples earlier.
 * @returns {number|null} null when fewer than `period` prices are available.
 */
function calculateMomentum(prices, period) {
	if (prices.length < period) return null;

	return prices[prices.length - 1] - prices[prices.length - period];
}

/**
 * Builds the model input for one price window. The same function is used for training and
 * prediction so both see identically scaled features.
 *
 * Layout: z-scored prices, then RSI / 100, then MACD line, signal line and momentum divided
 * by the window's price standard deviation (the same units as the z-scored prices).
 *
 * @param {number[]} windowPrices - the last `windowSize` prices.
 * @param {typeof INDICATOR_CONFIG} [config=INDICATOR_CONFIG]
 * @returns {{features: number[], indicators: {rsi: number, macdLine: number, signalLine: number, momentum: number}}|null}
 *   Unscaled indicator values are returned for display; null when any indicator is unavailable.
 */
function buildFeatureVector(windowPrices, config = INDICATOR_CONFIG) {
	const rsi = calculateRSI(windowPrices, config.rsiPeriod);
	const { macdLine, signalLine } = calculateMACD(
		windowPrices,
		config.fastPeriod,
		config.slowPeriod,
		config.signalPeriod
	);
	const momentum = calculateMomentum(windowPrices, config.momentumPeriod);
	if (rsi === null || momentum === null || !Number.isFinite(macdLine) || !Number.isFinite(signalLine)) {
		return null;
	}

	const average = arithmeticMean(windowPrices);
	const std = Math.sqrt(arithmeticMean(windowPrices.map((price) => (price - average) ** 2)));
	// A flat window has zero spread; fall back to 1 so it maps to zeros instead of NaN.
	const scale = std > 0 ? std : 1;

	const features = [
		...windowPrices.map((price) => (price - average) / scale),
		rsi / 100,
		macdLine / scale,
		signalLine / scale,
		momentum / scale,
	];
	if (!features.every(Number.isFinite)) return null;

	return { features, indicators: { rsi, macdLine, signalLine, momentum } };
}

/**
 * Slides a window over `prices` and labels each window 1 when the price `predictionTicks`
 * samples after the window's last price is higher, 0 otherwise (fall or unchanged).
 *
 * @returns {{inputs: number[][], labels: number[]}}
 */
function createDatasetWithIndicators(prices, windowSize, predictionTicks, config = INDICATOR_CONFIG) {
	const inputs = [];
	const labels = [];

	// `<=` so the final window whose future price is the very last element is included.
	for (let start = 0; start + windowSize + predictionTicks <= prices.length; start++) {
		const sample = buildFeatureVector(prices.slice(start, start + windowSize), config);
		if (sample === null) continue;

		const currentPrice = prices[start + windowSize - 1];
		const futurePrice = prices[start + windowSize + predictionTicks - 1];
		if (!Number.isFinite(currentPrice) || !Number.isFinite(futurePrice)) continue;

		inputs.push(sample.features);
		labels.push(futurePrice > currentPrice ? 1 : 0);
	}

	if (inputs.length === 0) {
		console.log("No se generaron suficientes entradas y etiquetas para el conjunto de entrenamiento.");
	}

	return { inputs, labels };
}

/**
 * Trains a rise/fall classifier on historical closes, then samples the live price every
 * second and renders the predicted probabilities.
 *
 * @param {object} tvWidget - TradingView widget (used only by the currently uncalled `decideTrade`).
 * @param {number} epochs - training epochs.
 * @param {number} predictionTicks - horizon, in samples, for the rise/fall label (expected >= 1).
 * @param {number} windowSize - prices per input window. At least 7 are needed with INDICATOR_CONFIG.
 * @returns {Promise<() => void>} resolves once training is done (or has failed).
 *   Call the returned function to stop the live intervals and dispose the model.
 */
export const runRiseFall = async (tvWidget, epochs, predictionTicks, windowSize) => {
	let prices = [];
	let model;
	const intervalIds = [];

	// Trains `model` in place; tensors are released even when fit() rejects.
	async function trainModel(model, inputs, labels) {
		const inputTensor = tf.tensor2d(inputs);
		const labelTensor = tf.tensor2d(labels, [labels.length, 1]);

		try {
			await model.fit(inputTensor, labelTensor, {
				batchSize: 64,
				epochs,
				validationSplit: 0.2,
				callbacks: {
					onEpochEnd: (epoch, logs) => {
						// tfjs reports the "accuracy" metric as `acc`; fall back to `accuracy` just in case.
						const accuracy = logs.acc ?? logs.accuracy;
						showToast(`Epoch ${epoch + 1}/${epochs} - loss: ${logs.loss.toFixed(6)} - Accuracy: ${(accuracy * 100).toFixed(2)}%`);
					},
					onTrainEnd: () => {
						showToast("Entrenamiento completado.");
					},
				},
			});
		} finally {
			inputTensor.dispose();
			labelTensor.dispose();
		}
	}

	// Binary classifier over windowSize normalized prices plus the indicator features.
	function createModelWithIndicators(windowSize) {
		const model = tf.sequential();
		model.add(tf.layers.dense({ inputShape: [windowSize + INDICATOR_FEATURE_COUNT], units: 64, activation: "relu" }));
		model.add(tf.layers.dense({ units: 32, activation: "relu" }));
		model.add(tf.layers.dropout({ rate: 0.2 }));
		model.add(tf.layers.dense({ units: 1, activation: "sigmoid" }));

		model.compile({ optimizer: "adam", loss: "binaryCrossentropy", metrics: ["accuracy"] });

		return model;
	}

	// Returns the sigmoid output as probUp and its complement as probDown.
	function predictRiseFall(model, inputData) {
		const probUp = tf.tidy(() => model.predict(tf.tensor2d([inputData])).arraySync()[0][0]);

		return { probUp, probDown: 1 - probUp };
	}

	// NOTE(review): not called anywhere in this module; draws a RISE/FALL arrow at the last bar.
	function decideTrade(prediction) {
		const lastBar = store.getState().lastBar;

		tvWidget.activeChart().createShape(
			{ time: lastBar.epoch, price: Number(lastBar.close) },
			{
				shape: prediction === 1 ? "arrow_up" : "arrow_down",
				text: prediction === 1 ? "RISE" : "FALL",
				overrides: {
					color: prediction === 1 ? "#4599d9" : "#FF0000",
					textColor: "white",
					fontsize: 14,
					bold: true,
				},
				zOrder: "top",
			}
		);
	}

	// Appends the store's last close once per second, keeping at most PRICE_HISTORY_LIMIT prices.
	// NOTE(review): history is fetched as "1T" closes, but live samples are taken every second.
	// Confirm that this cadence matches what `predictionTicks` is meant to measure.
	function startTickCollection() {
		const id = setInterval(() => {
			const newPrice = Number(store.getState().lastBar.close);

			if (Number.isNaN(newPrice)) {
				console.log("No se pudo obtener el nuevo precio.", newPrice);
				return;
			}

			prices.push(newPrice);
			if (prices.length > PRICE_HISTORY_LIMIT) {
				prices.shift();
			}
		}, 1000);
		intervalIds.push(id);
	}

	// Once per second, predicts from the latest window using the same features as training.
	function startPredictionInterval() {
		const id = setInterval(() => {
			if (prices.length < windowSize) {
				console.log("No hay suficientes datos para hacer una predicción.");
				return;
			}

			const sample = buildFeatureVector(prices.slice(-windowSize));
			if (sample === null) {
				console.log("No se pudo calcular RSI, MACD o Momentum.");
				return;
			}

			const probabilities = predictRiseFall(model, sample.features);
			const { rsi, macdLine, signalLine, momentum } = sample.indicators;
			updatePredictionDisplay(rsi, macdLine, signalLine, momentum, probabilities);
		}, 1000);
		intervalIds.push(id);
	}

	const processData = async (historicalPrices) => {
		showToast("Procesando datos para el entrenamiento...");
		showToast("Creando Dataset Optimizado.");
		const { inputs, labels } = createDatasetWithIndicators(historicalPrices, windowSize, predictionTicks);

		if (inputs.length === 0) {
			console.log("No hay suficientes datos para crear el conjunto de entrenamiento.");
			return;
		}

		model = createModelWithIndicators(windowSize);
		await trainModel(model, inputs, labels);
		startTickCollection();
		startPredictionInterval();
	};

	const stop = () => {
		for (const id of intervalIds) clearInterval(id);
		intervalIds.length = 0;
		model?.dispose();
		model = undefined;
	};

	showToast("Obteniendo datos históricos...");
	try {
		// Copy (and coerce to numbers) so live ticks never mutate the array getLastCloses returned.
		prices = (await getLastCloses("1T", PRICE_HISTORY_LIMIT)).map(Number);
		console.log(`Cantidad de datos históricos obtenidos: ${prices.length}`);
		await processData(prices);
	} catch (error) {
		console.log(error instanceof Error ? error.message : error);
	}

	return stop;
};

/**
 * Writes the latest indicator readings and model probabilities into the floating
 * "predictionBox" overlay. On the first call the overlay is created, appended to
 * <body> and made draggable.
 *
 * @param {number} rsi - RSI value; rendered green when above 90, orange otherwise.
 * @param {number} macdLine - MACD line value (5 decimals).
 * @param {number} signalLine - MACD signal line value (5 decimals).
 * @param {number} momentum - momentum value (5 decimals).
 * @param {{probUp: number, probDown: number}} probabilities - presumably in [0, 1]; multiplied by 100 for display.
 */
function updatePredictionDisplay(rsi, macdLine, signalLine, momentum, probabilities) {
	let panel = document.getElementById("predictionBox");
	if (!panel) {
		panel = document.createElement("div");
		panel.id = "predictionBox";
		panel.style.cssText = `
			position: absolute;
			bottom: 0%;
            left: 10%;
			transform: translate(-50%, -50%);
			background-color: rgba(23, 27, 38, 0.9);
			padding: 8px;
			border-radius: 8px;
			color: white;
			border: 1px solid #4599d9;
			box-shadow: 0 0 5px #4599d9, 0 0 10px #4599d9;
			font-family: 'Orbitron', sans-serif;
			font-size: 12px;
			line-height: 1.2;
			cursor: move;
			user-select: none;
		`;
		document.body.appendChild(panel);
		makeDraggable(panel);
	}

	// Render the indicators followed by the rise and fall probabilities.
	const rsiColor = rsi > 90 ? "#00ff00" : "#ffa500";
	panel.innerHTML = `
		<p>RSI: <span style="color: ${rsiColor}">${rsi.toFixed(2)}</span></p>		
		<p>MACD Line: ${macdLine.toFixed(5)}</p>
		<p>Signal Line: ${signalLine.toFixed(5)}</p>
		<p>Momentum: ${momentum.toFixed(5)}</p>
		<p>Up Probability: ${(probabilities.probUp * 100).toFixed(2)}%</p>
		<p>Down Probability: ${(probabilities.probDown * 100).toFixed(2)}%</p>
	`;
}

/**
 * Makes an absolutely positioned element draggable with the mouse (the floating window becomes movable).
 *
 * Once a drag moves the element, it is positioned by `top`/`left` only: `bottom` and `right`
 * are reset to `auto`. An absolutely positioned box with both `top` and `bottom` set and
 * `height: auto` is otherwise stretched between the two edges.
 * Listeners are attached with addEventListener, so other `document` handlers are not overwritten.
 *
 * @param {HTMLElement} element - element to move; its inline top/left/bottom/right are rewritten.
 */
function makeDraggable(element) {
	let lastX = 0;
	let lastY = 0;

	const onMouseMove = (event) => {
		event.preventDefault();
		const deltaX = event.clientX - lastX;
		const deltaY = event.clientY - lastY;
		lastX = event.clientX;
		lastY = event.clientY;

		// Read the current layout offsets BEFORE touching bottom/right; clearing `bottom` first
		// would momentarily move the box to its static position.
		const nextTop = element.offsetTop + deltaY;
		const nextLeft = element.offsetLeft + deltaX;
		element.style.top = `${nextTop}px`;
		element.style.left = `${nextLeft}px`;
		element.style.bottom = "auto";
		element.style.right = "auto";
	};

	const onMouseUp = () => {
		document.removeEventListener("mousemove", onMouseMove);
		document.removeEventListener("mouseup", onMouseUp);
	};

	element.addEventListener("mousedown", (event) => {
		event.preventDefault();
		lastX = event.clientX;
		lastY = event.clientY;
		document.addEventListener("mousemove", onMouseMove);
		document.addEventListener("mouseup", onMouseUp);
	});
}
