This article walks through the full production pipeline for building and using a Temporal Fusion Transformer (TFT) to predict Bitcoin's next-hour trend (bullish or bearish).

We'll cover:

  1. Requirements and environment setup.
  2. Dataset preparation (fetching BTC data and adding features).
  3. Dataset creation for TFT.
  4. Training the model.
  5. Running live predictions with entry/target/stop.
  6. Switching between CPU and GPU.

All scripts are provided in full. 🚀


1. Requirements (Environment Setup)

We start with a requirements-gpu.txt file:

# Core deep learning stack (CUDA 12.1 build)
torch==2.1.2+cu121
torchvision==0.16.2+cu121
torchaudio==2.1.2+cu121
--extra-index-url https://download.pytorch.org/whl/cu121

# PyTorch Lightning (renamed to lightning >=2.0)
lightning==2.1.4

# Forecasting & Tabular models
pytorch-forecasting==1.0.0
pytorch-tabnet==4.1.0

# Data handling
pandas>=2.0.0
numpy>=1.24.0
scikit-learn>=1.3.0

# Utils
python-dotenv>=1.0.0
tqdm>=4.66.0

# Optional (speedups + tuning)
optuna>=3.5.0
joblib>=1.3.0

Why these?

  • PyTorch → core deep learning framework.
  • Lightning → simpler training loop management.
  • PyTorch Forecasting → provides TFT.
  • Pandas, Numpy, Sklearn → data wrangling.
  • Optuna/Joblib → tuning and speedups.

👉 Install with:

pip install -r requirements-gpu.txt

For CPU only:

pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 lightning==2.1.4 pytorch-forecasting==1.0.0 pandas numpy scikit-learn

2. Dataset Preparation (btc_data.py)

This script fetches raw BTC data from Binance, resamples it, and engineers features.

# btc_data.py
import requests
import pandas as pd
import time
from datetime import datetime, timedelta, timezone
import os
import numpy as np

# === CONFIG ===
# Trading pair requested from the klines endpoint.
SYMBOL = "BTCUSDT"
# Candle interval of the raw download; main() also resamples it to 5m and 1h.
INTERVAL = "1m"
# Look-back window in days, counted back from the moment the script runs.
DAYS = 365
# Directory that receives the generated CSV files; created on import if missing.
OUTPUT_DIR = "data"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# REST endpoint queried by fetch_klines().
BASE_URL = "https://api.binance.com/api/v3/klines"
# Number of candles requested per call (sent as the `limit` query parameter).
LIMIT = 1000

def fetch_klines(symbol, interval, start_time, end_time, limit=1000):
    """Fetch one page of klines from the Binance REST API.

    Args:
        symbol: Trading pair, e.g. ``"BTCUSDT"``.
        interval: Kline interval string, e.g. ``"1m"``.
        start_time: Start of the requested window, epoch milliseconds.
        end_time: End of the requested window, epoch milliseconds.
        limit: Maximum number of klines to request.

    Returns:
        The decoded JSON payload, or an empty list when the request fails or
        the body is not valid JSON. The caller treats a ``dict`` payload as a
        failed request, so the payload is returned undecorated.
    """
    query = {
        "symbol": symbol,
        "interval": interval,
        "limit": limit,
        "startTime": start_time,
        "endTime": end_time,
    }
    try:
        # `params=` lets requests URL-encode the values instead of splicing
        # raw strings into the query string.
        response = requests.get(BASE_URL, params=query, timeout=10)
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # Only network/HTTP errors and undecodable bodies are expected here;
        # anything else is a programming error and should surface.
        print("⚠️ Error fetching:", e)
        return []

def get_binance_data(symbol, interval, days, limit=1000, max_retries=5):
    """Download the last ``days`` days of klines, one page at a time.

    Args:
        symbol: Trading pair, e.g. ``"BTCUSDT"``.
        interval: Kline interval string, e.g. ``"1m"``.
        days: Size of the look-back window, in days, ending now (UTC).
        limit: Maximum number of klines requested per page.
        max_retries: Number of consecutive failed or empty pages tolerated
            before the download stops early.

    Returns:
        A list of raw kline rows in the order they were received. If the
        retry budget is exhausted, the rows collected so far are returned.
    """
    # Sample the clock once so both window edges refer to the same instant.
    now = datetime.now(timezone.utc)
    end_time = int(now.timestamp() * 1000)
    start_time = int((now - timedelta(days=days)).timestamp() * 1000)

    all_data = []
    failures = 0
    while start_time < end_time:
        data = fetch_klines(symbol, interval, start_time, end_time, limit)
        if not data or isinstance(data, dict):
            failures += 1
            if failures > max_retries:
                # Without a bound, a persistent error or a window with no
                # candles would spin here forever.
                print(
                    f"Giving up after {max_retries} consecutive failed requests; "
                    f"returning {len(all_data)} rows collected so far."
                )
                break
            print("⚠️ Empty response, retrying...")
            time.sleep(1)
            continue

        failures = 0
        all_data.extend(data)
        # The first field of a kline row is its open time (ms). Resuming one
        # millisecond later selects the next candle for any interval, whereas
        # a fixed 60 s step is only right for intervals of one minute or more.
        start_time = data[-1][0] + 1
        time.sleep(0.3)
    return all_data

def add_features(df: pd.DataFrame) -> pd.DataFrame:
    """Add candle-shape, trend and volatility columns to an OHLCV frame.

    The frame must contain ``open``, ``high``, ``low``, ``close`` and
    ``volume`` columns. The new columns are written into ``df`` itself, and
    the same object is returned.
    """
    eps = 1e-9  # keeps the percentage ratios finite on zero-range candles
    previous_close = df["close"].shift()

    df["range"] = df["high"] - df["low"]

    # Exponential moving averages of the close.
    for column, span in {"ema9": 9, "ema21": 21, "ema50": 50}.items():
        df[column] = df["close"].ewm(span=span).mean()

    # True range: the widest of the three candidate spans, averaged over 14 bars.
    candidates = pd.concat(
        [
            df["high"] - df["low"],
            (df["high"] - previous_close).abs(),
            (df["low"] - previous_close).abs(),
        ],
        axis=1,
    )
    df["atr14"] = candidates.max(axis=1).rolling(14).mean()

    # Body and wick sizes as a fraction of the candle's full range.
    body_top = df[["close", "open"]].max(axis=1)
    body_bottom = df[["close", "open"]].min(axis=1)
    df["body_pct"] = (df["close"] - df["open"]).abs() / (df["range"] + eps)
    df["wick_upper_pct"] = (df["high"] - body_top) / (df["range"] + eps)
    df["wick_lower_pct"] = (body_bottom - df["low"]) / (df["range"] + eps)

    # Volume relative to its 20-bar mean, and ATR relative to price.
    average_volume = df["volume"].rolling(20).mean()
    df["rvol20"] = df["volume"] / average_volume
    df["atr_pct"] = df["atr14"] / df["close"]
    return df

def save_with_features(df, filename):
    """Write a feature-enriched copy of ``df`` to ``filename`` as CSV.

    The input frame is copied before the features are added, so the caller's
    frame is left untouched. Rows containing NaN are dropped before saving,
    and the index is written as the first CSV column.
    """
    enriched = add_features(df.copy()).dropna()
    enriched.to_csv(filename, index=True)
    print(f"βœ… Saved {len(enriched)} rows β†’ {filename}")

def main():
    """Download the configured klines and write raw, 5-minute and 1-hour CSVs.

    Each CSV is produced by ``save_with_features`` and lands in
    ``OUTPUT_DIR``. Nothing is written when the download returns no rows.
    """
    print(f"⏳ Fetching {DAYS} days of {SYMBOL} {INTERVAL} data...")
    raw_data = get_binance_data(SYMBOL, INTERVAL, DAYS, limit=LIMIT)
    if not raw_data:
        # Avoid writing header-only CSVs that would fail later in training.
        print("No data downloaded; nothing to save.")
        return

    ohlcv = ["open", "high", "low", "close", "volume"]
    df = pd.DataFrame(raw_data, columns=[
        "time", "open", "high", "low", "close", "volume",
        "close_time", "quote_asset_volume", "trades",
        "taker_buy_base", "taker_buy_quote", "ignore"
    ])
    # Copy so the column assignment below does not write into a slice.
    df = df[["time", *ohlcv]].copy()
    df["time"] = pd.to_datetime(df["time"], unit="ms")
    df = df.astype(dict.fromkeys(ohlcv, float))
    df = df.set_index("time")
    # Overlapping pages can repeat a candle: keep one row per timestamp, in
    # chronological order, so the rolling features see a clean series.
    df = df[~df.index.duplicated(keep="first")].sort_index()

    # Raw data: name the file after the interval that was actually fetched.
    save_with_features(df, os.path.join(OUTPUT_DIR, f"{SYMBOL}_{INTERVAL}.csv"))

    # Resampled data: standard OHLCV aggregation per bucket.
    aggregation = {
        "open": "first",
        "high": "max",
        "low": "min",
        "close": "last",
        "volume": "sum",
    }
    for rule, tag in (("5min", "5m"), ("1h", "1h")):
        bars = df.resample(rule).agg(aggregation).dropna()
        save_with_features(bars, os.path.join(OUTPUT_DIR, f"{SYMBOL}_{tag}.csv"))

# Run the download pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

👉 Run:

python btc_data.py

Generates:

  • data/BTCUSDT_1m.csv
  • data/BTCUSDT_5m.csv
  • data/BTCUSDT_1h.csv (we'll use this one).

3. Training TFT (btc_1h_trend_mirror_trainV2.py)

# btc_1h_trend_mirror_trainV2.py
import numpy as np
import pandas as pd
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.metrics import CrossEntropy
# The requirements install `lightning`, and pytorch-forecasting 1.0 models are
# built on `lightning.pytorch`. A Trainer from the separate `pytorch_lightning`
# namespace does not accept them, so the Trainer must come from here.
import lightning.pytorch as pl
import torch

# Hourly candles with the indicator columns produced by btc_data.py.
df = pd.read_csv("data/BTCUSDT_1h.csv")
df = df.dropna().reset_index(drop=True)

# Label: bullish if close > EMA21
df["y"] = np.where(df["close"] > df["ema21"], 1, 0)
# A gap-free integer time index and a single constant series id.
df["time_idx"] = np.arange(len(df))
df["series_id"] = "BTC"

feature_cols = [
    "open","high","low","close","volume",
    "ema9","ema21","ema50","atr14","range","body_pct",
    "wick_upper_pct","wick_lower_pct","rvol20","atr_pct"
]

# Chronological 80/20 split on the time index.
train_cutoff = int(df["time_idx"].max() * 0.8)

# Fit the dataset (and therefore its scalers/encoders) on the training period
# only, so no statistics from the validation period leak into training.
training = TimeSeriesDataSet(
    df[df.time_idx <= train_cutoff],
    time_idx="time_idx",
    target="y",
    group_ids=["series_id"],
    min_encoder_length=48,
    max_encoder_length=48,
    min_prediction_length=1,
    max_prediction_length=1,
    time_varying_unknown_reals=feature_cols,
    target_normalizer=NaNLabelEncoder(),
    categorical_encoders={"series_id": NaNLabelEncoder().fit(df.series_id)},
    add_relative_time_idx=True,
    add_target_scales=True,
)

# Validate on every window whose prediction step lies after the cutoff.
# `predict=True` would keep only the last window per series, i.e. a single
# validation sample here. Encoder windows may still reach back into the
# training period, which is fine: only the predicted step must be unseen.
val_ds = TimeSeriesDataSet.from_dataset(
    training,
    df,
    min_prediction_idx=train_cutoff + 1,
    stop_randomization=True,
)

train_loader = training.to_dataloader(train=True, batch_size=64, num_workers=0)
val_loader   = val_ds.to_dataloader(train=False, batch_size=64, num_workers=0)

# Two output logits (bearish / bullish) trained with cross-entropy.
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=1e-3,
    hidden_size=32,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=2,
    loss=CrossEntropy(),
)

trainer = pl.Trainer(
    max_epochs=30,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    gradient_clip_val=0.1,
)

trainer.fit(tft, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.save_checkpoint("MODELS/btc_1h_trend_mirror.ckpt")
print("βœ… Model saved at MODELS/btc_1h_trend_mirror.ckpt")

👉 Run:

python btc_1h_trend_mirror_trainV2.py

4. Live Prediction (TFT_predict.py)

# TFT_predict.py
import torch, requests, pandas as pd, numpy as np
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import NaNLabelEncoder

# Checkpoint file to load; same path the training script saves to.
CKPT_PATH = "MODELS/btc_1h_trend_mirror.ckpt"
# Encoder window length, prediction length, and a probability threshold.
# NOTE(review): THRESHOLD is not referenced anywhere else in the visible
# script — confirm whether low-confidence signals were meant to be filtered.
WINDOW, PRED_LEN, THRESHOLD = 48, 1, 0.55
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def fetch_binance(symbol="BTCUSDT", interval="1h", limit=500):
    """Fetch the most recent klines from Binance as an OHLCV DataFrame.

    Args:
        symbol: Trading pair to request.
        interval: Kline interval string.
        limit: Number of klines to request.

    Returns:
        A DataFrame with a datetime ``time`` column and float ``open``,
        ``high``, ``low``, ``close`` and ``volume`` columns.

    Raises:
        requests.RequestException: On network failure, timeout, or a non-2xx
            HTTP response.
    """
    url = "https://api.binance.com/api/v3/klines"
    # Without a timeout a stalled connection would block the script forever.
    resp = requests.get(
        url,
        params={"symbol": symbol, "interval": interval, "limit": limit},
        timeout=10,
    )
    # Fail here with the HTTP status instead of feeding an error payload into
    # the DataFrame constructor and getting an unrelated shape error.
    resp.raise_for_status()
    data = resp.json()

    df = pd.DataFrame(data, columns=["time","open","high","low","close","volume",
                                     "close_time","qav","num_trades","taker_base_vol","taker_quote_vol","ignore"])
    df["time"] = pd.to_datetime(df["time"], unit="ms")
    numeric_cols = ["open", "high", "low", "close", "volume"]
    # The numeric fields are cast in one step rather than column by column.
    df[numeric_cols] = df[numeric_cols].astype(float)
    return df[["time", *numeric_cols]]

# Must list the same columns as `feature_cols` in the training script; the
# checkpoint expects exactly these inputs.
feature_cols = [
    "open","high","low","close","volume",
    "ema9","ema21","ema50","atr14","range","body_pct",
    "wick_upper_pct","wick_lower_pct","rvol20","atr_pct"
]


def _add_features(frame):
    """Recreate the indicator columns that btc_data.py writes to the training CSV."""
    # Keep these formulas in sync with add_features() in btc_data.py.
    frame["range"] = frame["high"] - frame["low"]
    frame["ema9"] = frame["close"].ewm(span=9).mean()
    frame["ema21"] = frame["close"].ewm(span=21).mean()
    frame["ema50"] = frame["close"].ewm(span=50).mean()
    prev_close = frame["close"].shift()
    true_range = pd.concat(
        [
            frame["high"] - frame["low"],
            (frame["high"] - prev_close).abs(),
            (frame["low"] - prev_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    frame["atr14"] = true_range.rolling(14).mean()
    frame["body_pct"] = (frame["close"] - frame["open"]).abs() / (frame["range"] + 1e-9)
    frame["wick_upper_pct"] = (frame["high"] - frame[["close", "open"]].max(axis=1)) / (frame["range"] + 1e-9)
    frame["wick_lower_pct"] = (frame[["close", "open"]].min(axis=1) - frame["low"]) / (frame["range"] + 1e-9)
    frame["rvol20"] = frame["volume"] / frame["volume"].rolling(20).mean()
    frame["atr_pct"] = frame["atr14"] / frame["close"]
    return frame


df = fetch_binance()
df = df.sort_values("time").reset_index(drop=True)
# NOTE(review): the newest row may be a candle that is still forming, in which
# case its features are partial — confirm whether it should be dropped.

# Build the same inputs and label the model was trained on. Rows at the start
# lack rolling-window history and are dropped.
df = _add_features(df).dropna().reset_index(drop=True)
df["y"] = np.where(df["close"] > df["ema21"], 1, 0)
df["time_idx"] = np.arange(len(df))
df["series_id"] = "BTC"

if len(df) < WINDOW:
    raise RuntimeError(
        f"Need at least {WINDOW} complete rows to fill the encoder window, got {len(df)}."
    )

# The dataset predicts the last PRED_LEN rows of the frame. Append placeholder
# rows for the hours to forecast so every real candle sits in the encoder.
# All features are declared as *unknown* in the future, so the placeholder
# values (copied from the last candle) act only as padding.
future = pd.concat([df.iloc[[-1]]] * PRED_LEN, ignore_index=True)
future["time_idx"] = df["time_idx"].iloc[-1] + 1 + np.arange(PRED_LEN)
live_df = pd.concat([df, future], ignore_index=True)

model = TemporalFusionTransformer.load_from_checkpoint(CKPT_PATH, map_location=device)
model.to(device).eval()

# Rebuild the dataset from the parameters stored in the checkpoint so that
# encoders and scalers match what the network saw during training.
live_ds = TimeSeriesDataSet.from_parameters(
    model.dataset_parameters,
    live_df,
    predict=True,
    stop_randomization=True,
)
loader = live_ds.to_dataloader(train=False, batch_size=64, num_workers=0)

with torch.no_grad():
    preds = model.predict(loader, mode="raw")

# Depending on the library version, the raw result is either the network
# output itself or a wrapper exposing it as `.output`.
raw = getattr(preds, "output", preds)
if hasattr(raw, "keys"):
    logits = raw["prediction"]
elif isinstance(raw, (list, tuple)):
    logits = raw[0]
else:
    logits = raw
if logits.ndim == 3: logits = logits[:, -1, :]
probs = torch.softmax(logits, dim=-1).cpu().numpy()
# Class 0 is the bearish label (y == 0), class 1 the bullish label (y == 1).
prob_bearish, prob_bullish = (float(p) for p in probs[-1])
signal_ai = "Bullish" if prob_bullish > prob_bearish else "Bearish"

# Entry at the latest real close; fixed 2% target and 0.5% stop on the side
# implied by the signal.
entry_price = df.iloc[-1]["close"]
target_price = entry_price * (1.02 if signal_ai=="Bullish" else 0.98)
stop_loss   = entry_price * (0.995 if signal_ai=="Bullish" else 1.005)
expected_points = (target_price - entry_price) if signal_ai=="Bullish" else (entry_price - target_price)

live_output = {
    "signal_dtw": signal_ai,
    "signal_ai": signal_ai,
    "probability": round(max(prob_bullish, prob_bearish), 2),
    "entry_price": round(float(entry_price), 2),
    "target_price": round(float(target_price), 2),
    "stop_loss": round(float(stop_loss), 2),
    "expected_points": round(float(expected_points), 2),
}
print(f"βœ… Live Prediction: {live_output}")

Example Output:

βœ… Live Prediction: {
  'signal_dtw': 'Bullish',
  'signal_ai': 'Bullish',
  'probability': 0.95,
  'entry_price': 110849.85,
  'target_price': 113066.85,
  'stop_loss': 110295.6,
  'expected_points': 2217.0
}

5. Switching Between GPU and CPU

  • Auto-detect:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  • Force CPU:
device = torch.device("cpu")
  • Force GPU:
device = torch.device("cuda")

✅ Final Workflow

  1. Install requirements: pip install -r requirements-gpu.txt
  2. Prepare dataset: python btc_data.py
  3. Train model: python btc_1h_trend_mirror_trainV2.py
  4. Run live prediction: python TFT_predict.py

📌 Summary

  • btc_data.py → fetch + feature engineering.
  • btc_1h_trend_mirror_trainV2.py → create dataset + train TFT.
  • TFT_predict.py → run live predictions with probability, entry, target, stop.
  • Works on CPU or GPU with one line change.

This is a production-ready pipeline for BTC trend prediction with TFT.


Was this article helpful?
YesNo

Similar Posts