import { create, all } from "mathjs";
import * as tf from "@tensorflow/tfjs";
import { store } from "../../../../store/store.js";
import { getLastCloses } from "../../streaming.js";
import { toast } from "react-toastify";
import { saveAs } from "file-saver";

// mathjs configuration object; it is empty, so no options are overridden here.
const config = {};
// mathjs instance built from the full function set (`all`); this file uses its
// mean, variance and std helpers.
const math = create(all, config);

/**
 * Shows a short-lived dark toast in the bottom-right corner of the screen.
 *
 * @param {string} message - Text to display inside the toast.
 */
const showToast = (message) => {
	// Presentation settings shared by every notification raised from this module.
	const toastOptions = {
		position: "bottom-right",
		autoClose: 1000,
		hideProgressBar: false,
		closeOnClick: true,
		pauseOnHover: true,
		draggable: true,
		theme: "dark",
		progressStyle: { background: "#f74712" },
	};

	toast.dark(message, toastOptions);
};

export const runTouchNoTouch = async (tvWidget, epochs, predictionTicks, windowSize) => {
	let prices = [];
	let model;
	const barrierDistance = await localStorage.getItem("barrier");
	let atrValues = [];

	/**
	 * Simple (non-smoothed) Relative Strength Index over the latest price changes.
	 *
	 * Every consecutive price difference is split into a gain part and a loss
	 * part; the last `period` entries of each series are averaged with
	 * `math.mean` and combined as 100 - 100 / (1 + avgGain / avgLoss).
	 *
	 * @param {number[]} prices - Price series, oldest first.
	 * @param {number} [period=14] - How many of the latest changes to average.
	 * @returns {number|null} The RSI value, 100 when the averaged losses are
	 *   exactly zero, or null when there are fewer than `period` prices or the
	 *   result is NaN.
	 */
	function calculateRSI(prices, period = 14) {
		if (prices.length < period) {
			return null;
		}

		// One delta per consecutive pair: prices[k + 1] - prices[k].
		const deltas = prices.slice(1).map((price, index) => price - prices[index]);
		const gains = deltas.map((delta) => (delta >= 0 ? delta : 0));
		const losses = deltas.map((delta) => (delta >= 0 ? 0 : Math.abs(delta)));

		const averageGain = math.mean(gains.slice(-period));
		const averageLoss = math.mean(losses.slice(-period));

		// No losses in the averaged span: RSI saturates at its upper bound.
		if (averageLoss === 0) {
			return 100;
		}

		const rsi = 100 - 100 / (1 + averageGain / averageLoss);
		return isNaN(rsi) ? null : rsi;
	}

	/**
	 * Bollinger Bands computed from the last `period` prices.
	 *
	 * The middle band is the mean of that slice; the upper and lower bands sit
	 * `stdDevMultiplier` standard deviations (square root of `math.variance`)
	 * above and below it.
	 *
	 * @param {number[]} prices - Price series, oldest first.
	 * @param {number} [period=5] - Number of trailing prices to use.
	 * @param {number} [stdDevMultiplier=2] - Band width in standard deviations.
	 * @returns {{upperBand: number, lowerBand: number, middleBand: number}}
	 *   The three bands; all NaN when fewer than `period` prices are supplied.
	 */
	function calculateBollingerBands(prices, period = 5, stdDevMultiplier = 2) {
		if (prices.length < period) {
			return { upperBand: NaN, lowerBand: NaN, middleBand: NaN };
		}

		const recent = prices.slice(-period);
		const middleBand = math.mean(recent);
		const stdDev = Math.sqrt(math.variance(recent));
		const bandOffset = stdDevMultiplier * stdDev;

		return {
			upperBand: middleBand + bandOffset,
			lowerBand: middleBand - bandOffset,
			middleBand,
		};
	}

	/**
	 * Average of the most recent absolute close-to-close changes.
	 *
	 * Only closing prices are available here, so each "true range" is taken as
	 * |close[i] - close[i - 1]| (no highs or lows are involved).
	 *
	 * The average is taken over the ranges that actually exist: `n` closes give
	 * only `n - 1` ranges, so when `closes.length === period` the sum is divided
	 * by `period - 1` rather than by `period`, which would understate the value.
	 *
	 * @param {number[]} closes - Closing prices, oldest first.
	 * @param {number} [period=3] - Maximum number of trailing ranges to average.
	 * @returns {number} The averaged range, or NaN when `period` is not positive,
	 *   fewer than `period` closes are supplied, or no range can be formed.
	 */
	function calculateATR(closes, period = 3) {
		if (!(period > 0) || closes.length < period) return NaN;

		// Only the trailing `period` ranges are needed, so start right at them
		// instead of materialising the whole series first.
		const firstIndex = Math.max(1, closes.length - period);

		let total = 0;
		let count = 0;
		for (let i = firstIndex; i < closes.length; i++) {
			total += Math.abs(closes[i] - closes[i - 1]);
			count += 1;
		}

		return count === 0 ? NaN : total / count;
	}

	/**
	 * Starts the once-per-second prediction loop.
	 *
	 * On every tick the latest close is read from the store, appended to the
	 * rolling `prices` buffer (capped at 5000 entries) and turned into a feature
	 * vector: the z-scored last `windowSize` prices followed by the raw RSI and
	 * ATR of that window. The model output is then shown in the overlay and on
	 * the chart.
	 *
	 * @returns {*} The handle returned by `setInterval`, so the loop can be
	 *   stopped with `clearInterval`. Existing callers may ignore it.
	 */
	const startRealTimePrediction = () => {
		const intervalId = setInterval(() => {
			const lastBar = store.getState().lastBar;
			const rawClose = lastBar ? lastBar.close : undefined;

			// Number(null) and Number("") are both 0, which would be pushed as a
			// real price; treat a missing close as "no price" instead.
			const hasClose = rawClose !== null && rawClose !== undefined && rawClose !== "";
			const newPrice = hasClose ? Number(rawClose) : NaN;

			if (!Number.isFinite(newPrice)) {
				showToast("No se pudo obtener el nuevo precio.");
				return;
			}

			prices.push(newPrice);
			if (prices.length > 5000) {
				prices.shift();
			}

			const recentPrices = prices.slice(-windowSize);
			const normalizedRecentPrices = normalizeData(recentPrices);
			const rsi = calculateRSI(recentPrices, 5);
			const atr = calculateATR(recentPrices, 5);

			if (rsi === null || Number.isNaN(atr)) {
				showToast("No se pudo calcular RSI, Bandas de Bollinger o ATR.");
				return;
			}

			const inputWithIndicators = [...normalizedRecentPrices, rsi, atr];
			const probability = predictNoBarrierReach(model, inputWithIndicators);

			updatePredictionDisplay(rsi, atr, { probUp: probability, probDown: 1 - probability });
			decideTrade(probability);
		}, 1000);

		return intervalId;
	};

	/**
	 * Builds the training set from historical prices, trains a fresh model and
	 * then starts the real-time prediction loop.
	 *
	 * Each sample produced by `createDataset` is `windowSize` prices followed by
	 * the RSI and ATR indicators. Only the price part is z-scored; the indicator
	 * values are appended unchanged. This mirrors how the real-time loop builds
	 * its input (normalized prices + raw RSI/ATR), so the model is trained on the
	 * same feature layout and scale it later sees at prediction time.
	 *
	 * @param {number[]} prices - Historical prices, oldest first.
	 * @returns {Promise<void>} Resolves once training has finished and the
	 *   prediction loop has been started, or immediately after a toast when no
	 *   samples could be built.
	 */
	const processData = async (prices) => {
		showToast("Procesando datos para el entrenamiento...");
		const { inputs, labels } = createDataset(prices, windowSize, barrierDistance, predictionTicks);

		if (inputs.length === 0) {
			showToast("No hay suficientes datos para crear el conjunto de entrenamiento.");
			return;
		}

		const normalizedInputs = inputs.map((input) => {
			const priceWindow = input.slice(0, windowSize);
			const indicators = input.slice(windowSize);
			return [...normalizeData(priceWindow), ...indicators];
		});

		// windowSize prices + RSI + ATR.
		model = createModel(windowSize + 2);
		await trainModel(model, normalizedInputs, labels);

		startRealTimePrediction();
	};

	/**
	 * Saves the dataset as `dataset.csv` through `saveAs`.
	 *
	 * The file has one `InputN` column per feature (the count is taken from the
	 * first sample) plus a trailing `Label` column, then one row per sample.
	 *
	 * The content is handed to a Blob, so it must be plain CSV: a `data:` URI
	 * prefix belongs only in a URL and would end up inside the first header
	 * cell of the saved file.
	 *
	 * @param {Array<Array<number>>} inputs - Feature rows.
	 * @param {Array<number>} labels - Label for each row, index-aligned with `inputs`.
	 */
	function exportDatasetToCSV(inputs, labels) {
		// Header: Input1..InputN followed by Label.
		const featureCount = inputs.length > 0 ? inputs[0].length : 0;
		const headerCells = [];
		for (let column = 1; column <= featureCount; column++) {
			headerCells.push(`Input${column}`);
		}
		headerCells.push("Label");

		// One CSV line per sample, with its label appended as the last cell.
		const rows = inputs.map((input, index) => `${input.join(",")},${labels[index]}`);
		const csvContent = `${[headerCells.join(","), ...rows].join("\n")}\n`;

		// Wrap the text in a Blob and trigger the download.
		const blob = new Blob([csvContent], { type: "text/csv;charset=utf-8;" });
		saveAs(blob, "dataset.csv");
	}

	/**
	 * Slides a window over `prices` and builds labelled training samples.
	 *
	 * Each sample is the `windowSize` prices of the window followed by the RSI
	 * and ATR of that window (both computed with period 5). Its label is 1 when
	 * any of the next `predictionTicks` prices reaches the barrier, 0 otherwise.
	 * The barrier is the window's last price plus the barrier distance; a
	 * positive distance is reached from below (price >= barrier), any other
	 * distance from above (price <= barrier).
	 *
	 * @param {number[]} prices - Historical prices, oldest first.
	 * @param {number} windowSize - Number of prices per sample.
	 * @param {string|number} barrierDistance - Signed distance to the barrier;
	 *   parsed with `parseFloat`.
	 * @param {number} predictionTicks - How many future prices are inspected.
	 * @returns {{inputs: number[][], labels: number[]}} Index-aligned samples and
	 *   labels. Both are empty when the barrier distance is not a number, since
	 *   every comparison against a NaN barrier is false and all samples would
	 *   silently be labelled 0.
	 */
	function createDataset(prices, windowSize, barrierDistance, predictionTicks) {
		const inputs = [];
		const labels = [];
		const parsedBarrierDistance = parseFloat(barrierDistance);

		if (Number.isNaN(parsedBarrierDistance)) {
			console.warn(`createDataset: barrier distance "${barrierDistance}" is not a number; no samples created`);
			return { inputs, labels };
		}

		// The barrier direction is fixed for the whole dataset.
		const barrierIsAbove = parsedBarrierDistance > 0;

		for (let i = 0; i <= prices.length - windowSize - predictionTicks; i++) {
			const inputPrices = prices.slice(i, i + windowSize);
			const rsi = calculateRSI(inputPrices, 5);
			const atr = calculateATR(inputPrices, 5);

			if (rsi === null || Number.isNaN(atr)) {
				console.log(`Skipping block at index: ${i} due to invalid indicators`);
				continue;
			}

			const currentPrice = inputPrices[inputPrices.length - 1];
			const barrier = currentPrice + parsedBarrierDistance;
			const futurePrices = prices.slice(i + windowSize, i + windowSize + predictionTicks);

			const reachedBarrier = barrierIsAbove
				? futurePrices.some((price) => price >= barrier)
				: futurePrices.some((price) => price <= barrier);

			inputs.push([...inputPrices, rsi, atr]);
			labels.push(reachedBarrier ? 1 : 0);
		}

		return { inputs, labels };
	}

	/**
	 * Z-score normalization: (value - mean) / standard deviation.
	 *
	 * @param {number[]} data - Values to normalize.
	 * @returns {number[]} A new array of normalized values. An empty input gives
	 *   an empty array. When every value is identical (standard deviation 0) the
	 *   result is all zeros, because dividing by zero would turn every feature
	 *   into NaN.
	 */
	function normalizeData(data) {
		if (data.length === 0) return [];

		const mean = math.mean(data);
		const std = math.std(data);

		// A flat series has no spread: every value sits exactly on the mean.
		if (std === 0) return data.map(() => 0);

		return data.map((val) => (val - mean) / std);
	}

	/**
	 * Builds and compiles the binary classifier.
	 *
	 * Architecture: dense(64, relu) -> dense(32, relu) -> dense(1, sigmoid),
	 * compiled with the Adam optimizer, binary cross-entropy loss and the
	 * accuracy metric.
	 *
	 * @param {number} inputSize - Length of each input feature vector.
	 * @returns {*} The compiled sequential model returned by `tf.sequential()`.
	 */
	function createModel(inputSize) {
		// Layer definitions in stacking order; only the first one declares the input shape.
		const layerConfigs = [
			{ inputShape: [inputSize], units: 64, activation: "relu" },
			{ units: 32, activation: "relu" },
			{ units: 1, activation: "sigmoid" },
		];

		const network = tf.sequential();
		for (const layerConfig of layerConfigs) {
			network.add(tf.layers.dense(layerConfig));
		}

		network.compile({
			optimizer: "adam",
			loss: "binaryCrossentropy",
			metrics: ["accuracy"],
		});

		return network;
	}

	/**
	 * Fits the model on the given samples and reports progress through toasts.
	 *
	 * Training runs for `epochs` epochs (taken from the enclosing scope) with a
	 * batch size of 64 and 20% of the data held out for validation.
	 *
	 * The input and label tensors are disposed once fitting finishes or fails;
	 * TensorFlow.js tensors are not reclaimed by the garbage collector, so
	 * leaving them alive would leak memory on every training run.
	 *
	 * @param {*} model - Compiled model exposing `fit`.
	 * @param {number[][]} inputs - Feature rows.
	 * @param {number[]} labels - One label per row.
	 * @returns {Promise<void>} Resolves when training has completed.
	 */
	async function trainModel(model, inputs, labels) {
		showToast("Entrenando el modelo...");

		// The metric key differs between builds (`acc` vs `accuracy`). Pick the
		// one that is present; a plain `||` would discard a legitimate 0.
		const firstDefined = (primary, fallback) => (primary !== undefined ? primary : fallback);

		const inputTensor = tf.tensor2d(inputs);
		let labelTensor;
		try {
			labelTensor = tf.tensor2d(labels, [labels.length, 1]);

			await model.fit(inputTensor, labelTensor, {
				batchSize: 64,
				epochs: epochs,
				validationSplit: 0.2,
				callbacks: {
					onEpochEnd: (epoch, logs) => {
						const accuracy = firstDefined(logs.acc, logs.accuracy);
						const valAccuracy = firstDefined(logs.val_acc, logs.val_accuracy);
						showToast(
							`Epoch ${epoch + 1}/${epochs} - accuracy: ${(accuracy * 100).toFixed(2)}% - val_accuracy: ${(valAccuracy * 100).toFixed(2)}%`
						);
					},
					onTrainEnd: () => {
						showToast("Entrenamiento completado.");
					},
				},
			});
		} finally {
			inputTensor.dispose();
			if (labelTensor) {
				labelTensor.dispose();
			}
		}
	}

	/**
	 * Runs the model on a single feature vector and returns its scalar output.
	 *
	 * NOTE(review): despite the function name, `createDataset` labels a sample 1
	 * when the barrier WAS reached, so this value presumably reads as the
	 * probability of touching the barrier — confirm before relying on the name.
	 *
	 * Both the input tensor and the prediction tensor are disposed before
	 * returning. This function is called once per second by the prediction loop
	 * and TensorFlow.js tensors are not garbage collected, so skipping the
	 * disposal leaks two tensors per tick.
	 *
	 * @param {*} model - Trained model exposing `predict`.
	 * @param {number[]} recentPrices - One complete feature vector.
	 * @returns {number} The first output value of the model for that vector.
	 */
	function predictNoBarrierReach(model, recentPrices) {
		const inputTensor = tf.tensor2d([recentPrices]);
		let prediction;
		try {
			prediction = model.predict(inputTensor);
			return prediction.arraySync()[0][0];
		} finally {
			inputTensor.dispose();
			if (prediction) {
				prediction.dispose();
			}
		}
	}

	/**
	 * Refreshes the floating overlay that shows RSI, ATR and the touch probability.
	 *
	 * The ATR is also recorded in a rolling history of at most 180 values so the
	 * current reading can be coloured against its recent average. The overlay
	 * element (`#predictionBox`) is created, styled, attached to the document
	 * body and made draggable the first time it is needed; later calls only
	 * replace its content.
	 *
	 * @param {number} rsi - Current RSI value.
	 * @param {number} atr - Current ATR value.
	 * @param {{probUp: number}} probabilities - Model output; only `probUp`
	 *   (a 0..1 value) is displayed.
	 */
	function updatePredictionDisplay(rsi, atr, probabilities) {
		// Rolling ATR history, capped at 180 entries.
		atrValues.push(atr);
		if (atrValues.length > 180) {
			atrValues.shift();
		}

		let atrTotal = 0;
		for (const value of atrValues) {
			atrTotal += value;
		}
		const atrAverage = atrTotal / atrValues.length;

		// RSI: green above 90, red below 10, orange in between.
		let rsiColor = "#ffa500";
		if (rsi > 90) {
			rsiColor = "#00ff00";
		} else if (rsi < 10) {
			rsiColor = "#ff0000";
		}

		// ATR: orange when above its rolling average, green otherwise.
		const atrColor = atr > atrAverage ? "#ffa500" : "#00ff00";

		// Touch probability: green above 70%, red below 50%, orange in between.
		const touchPercent = probabilities.probUp * 100;
		let touchColor = "#ffa500";
		if (touchPercent > 70) {
			touchColor = "#00ff00";
		} else if (touchPercent < 50) {
			touchColor = "#ff0000";
		}

		let box = document.getElementById("predictionBox");
		if (!box) {
			box = document.createElement("div");
			box.id = "predictionBox";
			box.style.cssText = `
            position: absolute;
            bottom: 0%;
            left: 10%;
            transform: translate(-50%, -50%);
            background-color: rgba(23, 27, 38, 0.9);
            padding: 8px;
            border-radius: 8px;
            color: white;
            border: 1px solid #4599d9;
            box-shadow: 0 0 5px #4599d9, 0 0 10px #4599d9;
            font-family: 'Orbitron', sans-serif;
            font-size: 12px;
            line-height: 1.2;
            cursor: move;
            user-select: none;
        `;
			document.body.appendChild(box);
			makeDraggable(box);
		}

		box.innerHTML = `
        <p>RSI: <span style="color: ${rsiColor}">${rsi.toFixed(2)}</span></p>
        <p>ATR: <span style="color: ${atrColor}">${atr.toFixed(5)}</span></p>
		<p>Toca: <span style="color: ${touchColor}">${touchPercent.toFixed(2)}%</span></p>
    `;
	}

	/**
	 * Lets the user drag `element` around with the mouse.
	 *
	 * Pressing the mouse button on the element installs document-level
	 * `onmousemove` / `onmouseup` handlers; each move shifts the element's
	 * `top` / `left` by the distance the pointer travelled since the previous
	 * event, and releasing the button removes both handlers again.
	 *
	 * While dragging, `bottom` and `right` are reset to `auto`. An absolutely
	 * positioned box that has both `top` and `bottom` (or both `left` and
	 * `right`) set with an automatic size is stretched between the two offsets,
	 * so an element initially anchored with `bottom` would otherwise grow
	 * instead of move.
	 *
	 * @param {HTMLElement} element - Positioned element to make draggable.
	 */
	function makeDraggable(element) {
		// Pointer position at the previous mouse event.
		let lastX = 0;
		let lastY = 0;

		function closeDragElement() {
			document.onmouseup = null;
			document.onmousemove = null;
		}

		function elementDrag(e) {
			e = e || window.event;
			e.preventDefault();

			const deltaX = lastX - e.clientX;
			const deltaY = lastY - e.clientY;
			lastX = e.clientX;
			lastY = e.clientY;

			// Read the layout position before touching any style property.
			const nextTop = element.offsetTop - deltaY;
			const nextLeft = element.offsetLeft - deltaX;

			element.style.top = nextTop + "px";
			element.style.left = nextLeft + "px";
			element.style.bottom = "auto";
			element.style.right = "auto";
		}

		function dragMouseDown(e) {
			e = e || window.event;
			e.preventDefault();
			lastX = e.clientX;
			lastY = e.clientY;
			document.onmouseup = closeDragElement;
			document.onmousemove = elementDrag;
		}

		element.onmousedown = dragMouseDown;
	}

	/**
	 * Draws the current probability as a text label on the active chart.
	 *
	 * Every shape on the active chart is removed first, then a single text shape
	 * "P: xx.xx%" is placed at the last bar's epoch and high. The label is green
	 * when the probability is above 0.85 and red otherwise.
	 *
	 * @param {number} probability - Model output in the 0..1 range.
	 * @param {number} [threshold=0.7] - Accepted for interface compatibility;
	 *   not used by the current implementation.
	 */
	function decideTrade(probability, threshold = 0.7) {
		const { chart: widget, lastBar } = store.getState();
		const activeChart = widget.activeChart();

		const anchorPoint = { time: lastBar.epoch, price: lastBar.high };
		const labelColor = probability > 0.85 ? "green" : "red";
		const labelOptions = {
			shape: "text",
			text: `P: ${(probability * 100).toFixed(2)}%`,
			overrides: {
				backgroundColor: "transparent",
				color: labelColor,
				fontsize: 14,
				bold: true,
				fixedPosition: true,
				vertAlign: "top",
				horzAlign: "right",
			},
			zOrder: "top",
		};

		// Clear the previous label (and any other shape) before drawing the new one.
		activeChart.removeAllShapes();
		activeChart.createShape(anchorPoint, labelOptions);
	}

	/**
	 * Entry point of the run: loads history, then trains and starts predicting.
	 *
	 * Requests the last 5000 closes on the "1T" timeframe. When none come back a
	 * toast is shown and nothing else happens; otherwise the closes become the
	 * shared `prices` buffer and are handed to `processData`.
	 *
	 * @returns {Promise<void>} Resolves when `processData` has finished or the
	 *   "no data" toast has been shown.
	 */
	const init = async () => {
		const ohlcData = await getLastCloses("1T", 5000);

		// `!(length > 0)` keeps the exact condition of the positive branch inverted.
		if (!(ohlcData.length > 0)) {
			showToast("No se obtuvieron datos históricos.");
			return;
		}

		prices = ohlcData;
		await processData(prices);
	};
	await init();
};
