
This article walks through the full production pipeline for building and using a Temporal Fusion Transformer (TFT) to predict Bitcoin's next-hour trend (bullish or bearish).
We'll cover:
- Requirements and environment setup.
- Dataset preparation (fetching BTC data and adding features).
- Dataset creation for TFT.
- Training the model.
- Running live predictions with entry/target/stop.
- Switching between CPU and GPU.
All scripts are provided in full.
1. Requirements (Environment Setup)
We start with a requirements-gpu.txt file:
# Core deep learning stack (CUDA 12.1 build)
# The +cu121 wheels are hosted on the PyTorch index. It is added as an *extra*
# index so that every other package in this file is still resolved from PyPI
# (a plain --index-url would replace PyPI for the whole file).
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.1.2+cu121
torchvision==0.16.2+cu121
torchaudio==2.1.2+cu121
# PyTorch Lightning (renamed to lightning >=2.0)
lightning==2.1.4
# Forecasting & Tabular models
pytorch-forecasting==1.0.0
pytorch-tabnet==4.1.0
# Data handling
pandas>=2.0.0
# torch 2.1.x wheels are built against the NumPy 1.x ABI, so stay below 2.0
numpy>=1.24.0,<2.0
scikit-learn>=1.3.0
# Utils
python-dotenv>=1.0.0
tqdm>=4.66.0
# Optional (speedups + tuning)
optuna>=3.5.0
joblib>=1.3.0
Why these?
- PyTorch → core deep learning framework.
- Lightning → simpler training loop management.
- PyTorch Forecasting → provides TFT.
- Pandas, Numpy, Sklearn → data wrangling.
- Optuna/Joblib → tuning and speedups.
👉 Install with:
pip install -r requirements-gpu.txt
For CPU only:
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 lightning==2.1.4 pytorch-forecasting==1.0.0 pandas numpy scikit-learn
2. Dataset Preparation (btc_data.py)
This script fetches raw BTC data from Binance, resamples it, and engineers features.
# btc_data.py
import requests
import pandas as pd
import time
from datetime import datetime, timedelta, timezone
import os
import numpy as np
# === CONFIG ===
# Market, candle size and how far back to download.
SYMBOL = "BTCUSDT"
INTERVAL = "1m"
DAYS = 365

# Binance klines endpoint and the page size requested per call.
BASE_URL = "https://api.binance.com/api/v3/klines"
LIMIT = 1_000

# Folder receiving the generated CSV files; created at import time if missing.
OUTPUT_DIR = "data"
os.makedirs(OUTPUT_DIR, exist_ok=True)
def fetch_klines(symbol, interval, start_time, end_time, limit=1000):
    """Fetch one page of klines from the Binance REST endpoint ``BASE_URL``.

    Args:
        symbol: Trading pair, e.g. ``"BTCUSDT"``.
        interval: Kline interval string, e.g. ``"1m"``.
        start_time: Start of the window, in milliseconds since the Unix epoch.
        end_time: End of the window, in milliseconds since the Unix epoch.
        limit: Maximum number of klines requested in this call.

    Returns:
        The decoded JSON payload as returned by the API. An empty list is
        returned when the request fails or the body is not valid JSON, so the
        caller can treat "no data" and "request failed" the same way.
    """
    # Let requests build and encode the query string instead of formatting it by hand.
    params = {
        "symbol": symbol,
        "interval": interval,
        "limit": limit,
        "startTime": start_time,
        "endTime": end_time,
    }
    try:
        response = requests.get(BASE_URL, params=params, timeout=10)
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # RequestException: connection/timeout problems; ValueError: undecodable JSON body.
        print("⚠️ Error fetching:", e)
        return []
def get_binance_data(symbol, interval, days, limit=1000, max_retries=10):
    """Download ``days`` of klines by paging through the Binance API.

    Args:
        symbol: Trading pair, e.g. ``"BTCUSDT"``.
        interval: Kline interval string, e.g. ``"1m"``.
        days: How many days back from now the download starts.
        limit: Page size passed to :func:`fetch_klines`.
        max_retries: Number of consecutive empty/error responses tolerated
            before the download stops and the rows collected so far are
            returned.

    Returns:
        A list of raw kline rows in the order they were received.
    """
    # Sample the clock once so both ends of the window share the same reference.
    now = datetime.now(timezone.utc)
    end_time = int(now.timestamp() * 1000)
    start_time = int((now - timedelta(days=days)).timestamp() * 1000)

    all_data = []
    failures = 0
    while start_time < end_time:
        data = fetch_klines(symbol, interval, start_time, end_time, limit)
        # A dict is an API error payload; an empty list is a failed or empty page.
        if not data or isinstance(data, dict):
            failures += 1
            if failures > max_retries:
                print(
                    f"⚠️ Giving up after {max_retries} consecutive empty responses; "
                    f"returning {len(all_data)} rows"
                )
                break
            print("⚠️ Empty response, retrying...")
            time.sleep(1)
            continue

        failures = 0
        all_data.extend(data)
        # Element 6 of a kline row is its close time (see the column list used
        # in main), so close time + 1 ms is the open time of the next candle.
        # For 1m candles this equals the previous "open time + 60_000" step,
        # and it also works for every other interval.
        start_time = data[-1][6] + 1
        time.sleep(0.3)  # pause between pages
    return all_data
def add_features(df: pd.DataFrame) -> pd.DataFrame:
    """Append trend, volatility and candle-shape columns to an OHLCV frame.

    The frame must provide ``open``, ``high``, ``low``, ``close`` and
    ``volume`` columns. It is modified in place and also returned. The
    rolling columns (``atr14``, ``rvol20`` and the derived ``atr_pct``) are
    NaN until their window is filled.
    """
    eps = 1e-9  # keeps the candle ratios finite when high == low
    close = df["close"]
    prev_close = close.shift()

    df["range"] = df["high"] - df["low"]

    # Exponential moving averages of the close.
    for span in (9, 21, 50):
        df[f"ema{span}"] = close.ewm(span=span).mean()

    # True range: the largest of high-low, |high-prev close|, |low-prev close|,
    # then averaged over 14 bars.
    true_range = pd.concat(
        [
            df["high"] - df["low"],
            (df["high"] - prev_close).abs(),
            (df["low"] - prev_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    df["atr14"] = true_range.rolling(14).mean()

    # Candle anatomy expressed as fractions of the full high-low range.
    candle_span = df["range"] + eps
    body_top = df[["close", "open"]].max(axis=1)
    body_bottom = df[["close", "open"]].min(axis=1)
    df["body_pct"] = (close - df["open"]).abs() / candle_span
    df["wick_upper_pct"] = (df["high"] - body_top) / candle_span
    df["wick_lower_pct"] = (body_bottom - df["low"]) / candle_span

    # Volume relative to its 20-bar mean, and ATR relative to price.
    df["rvol20"] = df["volume"] / df["volume"].rolling(20).mean()
    df["atr_pct"] = df["atr14"] / close
    return df
def save_with_features(df, filename):
    """Add features to a copy of ``df`` and write the result to ``filename``.

    The caller's frame is left untouched because the features are computed on
    a copy. Rows containing NaN (for example the warm-up rows of the rolling
    indicators) are dropped before writing, and the index is written as the
    first CSV column.
    """
    df_feat = add_features(df.copy()).dropna()
    df_feat.to_csv(filename, index=True)
    # The message is kept on a single line so the f-string literal stays valid.
    print(f"✅ Saved {len(df_feat)} rows → {filename}")
def main():
    """Download the configured history and write 1m, 5m and 1h feature CSVs.

    The raw klines are reduced to time/open/high/low/close/volume, indexed by
    their open time, then saved as-is and resampled to 5-minute and 1-hour
    candles. Every file goes through :func:`save_with_features`.
    """
    print(f"⏳ Fetching {DAYS} days of {SYMBOL} {INTERVAL} data...")
    raw_data = get_binance_data(SYMBOL, INTERVAL, DAYS, limit=LIMIT)
    if not raw_data:
        # Without this guard three header-only CSV files would be written silently.
        print("⚠️ No data downloaded; nothing to save.")
        return

    df = pd.DataFrame(raw_data, columns=[
        "time", "open", "high", "low", "close", "volume",
        "close_time", "quote_asset_volume", "trades",
        "taker_buy_base", "taker_buy_quote", "ignore"
    ])
    df = df[["time", "open", "high", "low", "close", "volume"]]
    df["time"] = pd.to_datetime(df["time"], unit="ms")
    df = df.astype({"open": float, "high": float, "low": float, "close": float, "volume": float})
    df = df.set_index("time")

    # Save raw and resampled data
    save_with_features(df, os.path.join(OUTPUT_DIR, f"{SYMBOL}_1m.csv"))

    # One shared OHLCV aggregation for every resampled timeframe.
    ohlcv_agg = {"open": "first", "high": "max", "low": "min", "close": "last", "volume": "sum"}
    for rule, suffix in (("5min", "5m"), ("1h", "1h")):
        # dropna() removes buckets that contained no source candles.
        resampled = df.resample(rule).agg(ohlcv_agg).dropna()
        save_with_features(resampled, os.path.join(OUTPUT_DIR, f"{SYMBOL}_{suffix}.csv"))


if __name__ == "__main__":
    main()
👉 Run:
python btc_data.py
Generates:
data/BTCUSDT_1m.csv
data/BTCUSDT_5m.csv
data/BTCUSDT_1h.csv
(we'll use this one).
3. Training TFT (btc_1h_trend_mirror_trainV2.py)
# btc_1h_trend_mirror_trainV2.py
import os

import numpy as np
import pandas as pd
import torch
# pytorch-forecasting 1.0 builds its models on `lightning.pytorch` (the package
# pinned in requirements-gpu.txt). The Trainer has to come from the same
# package, otherwise fit() rejects the model as "not a LightningModule".
import lightning.pytorch as pl
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.metrics import CrossEntropy

CKPT_DIR = "MODELS"
CKPT_PATH = f"{CKPT_DIR}/btc_1h_trend_mirror.ckpt"

df = pd.read_csv("data/BTCUSDT_1h.csv")
df = df.dropna().reset_index(drop=True)

# Label: bullish if close > EMA21
df["y"] = np.where(df["close"] > df["ema21"], 1, 0)
df["time_idx"] = np.arange(len(df))
df["series_id"] = "BTC"

feature_cols = [
    "open", "high", "low", "close", "volume",
    "ema9", "ema21", "ema50", "atr14", "range", "body_pct",
    "wick_upper_pct", "wick_lower_pct", "rvol20", "atr_pct",
]

# Chronological 80/20 split on the time index.
train_cutoff = int(df["time_idx"].max() * 0.8)

# The dataset is built on the training rows only, so the scalers and encoders
# it fits never see the validation period.
training = TimeSeriesDataSet(
    df[df.time_idx <= train_cutoff],
    time_idx="time_idx",
    target="y",
    group_ids=["series_id"],
    min_encoder_length=48,
    max_encoder_length=48,
    min_prediction_length=1,
    max_prediction_length=1,
    time_varying_unknown_reals=feature_cols,
    target_normalizer=NaNLabelEncoder(),
    categorical_encoders={"series_id": NaNLabelEncoder().fit(df.series_id)},
    add_relative_time_idx=True,
    add_target_scales=True,
)
train_ds = training

# Validation reuses the training encoders/scalers. `predict=True` would keep
# only the last window of the single series (one sample), so every validation
# window is used instead.
val_ds = TimeSeriesDataSet.from_dataset(
    training, df[df.time_idx > train_cutoff], stop_randomization=True
)

train_loader = train_ds.to_dataloader(train=True, batch_size=64, num_workers=0)
val_loader = val_ds.to_dataloader(train=False, batch_size=64, num_workers=0)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=1e-3,
    hidden_size=32,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=2,  # two classes: bearish (0) and bullish (1)
    loss=CrossEntropy(),
)

trainer = pl.Trainer(
    max_epochs=30,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    gradient_clip_val=0.1,
)
trainer.fit(tft, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Make sure the target folder exists before writing the checkpoint.
os.makedirs(CKPT_DIR, exist_ok=True)
trainer.save_checkpoint(CKPT_PATH)
print("✅ Model saved at MODELS/btc_1h_trend_mirror.ckpt")
👉 Run:
python btc_1h_trend_mirror_trainV2.py
4. Live Prediction (TFT_predict.py)
# TFT_predict.py
import numpy as np
import pandas as pd
import requests
import torch
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import NaNLabelEncoder

# Reuse the exact feature engineering that produced the training CSV.
# Importing btc_data also creates its "data" output folder as a side effect.
from btc_data import add_features

CKPT_PATH = "MODELS/btc_1h_trend_mirror.ckpt"
# THRESHOLD is kept for configuration; it is not applied to the signal below.
WINDOW, PRED_LEN, THRESHOLD = 48, 1, 0.55
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def fetch_binance(symbol="BTCUSDT", interval="1h", limit=500):
    """Download the most recent klines from Binance as a float OHLCV frame.

    Args:
        symbol: Trading pair to query.
        interval: Kline interval string.
        limit: Number of most recent klines to request.

    Returns:
        A DataFrame with ``time`` (datetime) and float ``open``, ``high``,
        ``low``, ``close``, ``volume`` columns.

    Raises:
        requests.RequestException: On network errors, timeouts or a non-2xx
            HTTP status.
    """
    url = "https://api.binance.com/api/v3/klines"
    resp = requests.get(
        url,
        params={"symbol": symbol, "interval": interval, "limit": limit},
        timeout=10,
    )
    resp.raise_for_status()  # fail loudly instead of parsing an error payload
    data = resp.json()
    df = pd.DataFrame(data, columns=[
        "time", "open", "high", "low", "close", "volume",
        "close_time", "qav", "num_trades", "taker_base_vol", "taker_quote_vol", "ignore",
    ])
    df["time"] = pd.to_datetime(df["time"], unit="ms")
    ohlcv = ["open", "high", "low", "close", "volume"]
    df[ohlcv] = df[ohlcv].astype(float)
    return df[["time", *ohlcv]]


def _extract_logits(raw):
    """Return the class-logit tensor from the result of ``predict(mode="raw")``.

    The result shape differs between pytorch-forecasting versions: it may be a
    wrapper exposing ``.output``, an object exposing ``prediction``, a tuple
    whose first element is the prediction, or already a tensor.
    """
    output = getattr(raw, "output", raw)
    if torch.is_tensor(output):
        return output
    if hasattr(output, "prediction"):
        return output.prediction
    if isinstance(output, dict):
        return output["prediction"]
    return output[0]


df = fetch_binance()
df = df.sort_values("time").reset_index(drop=True)

# Build the same inputs the model was trained on: engineered features, the
# label column (the dataset requires the target column to exist) and the
# bookkeeping columns.
df = add_features(df).dropna().reset_index(drop=True)
df["y"] = np.where(df["close"] > df["ema21"], 1, 0)
df["time_idx"] = np.arange(len(df))
df["series_id"] = "BTC"

if len(df) < WINDOW + PRED_LEN:
    raise ValueError(
        f"Need at least {WINDOW + PRED_LEN} rows after feature warm-up, got {len(df)}"
    )

model = TemporalFusionTransformer.load_from_checkpoint(CKPT_PATH, map_location=device)
model.to(device).eval()

# Recreate the dataset from the parameters stored in the checkpoint so the
# live data goes through the encoders and scalers fitted during training.
# predict=True keeps only the most recent window.
base_ds = TimeSeriesDataSet.from_parameters(
    model.dataset_parameters, df, predict=True, stop_randomization=True
)
loader = base_ds.to_dataloader(train=False, batch_size=64)

with torch.no_grad():
    preds = model.predict(loader, mode="raw")

logits = _extract_logits(preds)
if logits.ndim == 3:
    # (batch, horizon, classes) -> keep the last horizon step
    logits = logits[:, -1, :]
probs = torch.softmax(logits, dim=-1).cpu().numpy()
# Class order follows the training labels: index 0 = bearish, index 1 = bullish.
prob_bearish, prob_bullish = (float(p) for p in probs[-1])

signal_ai = "Bullish" if prob_bullish > prob_bearish else "Bearish"
entry_price = df.iloc[-1]["close"]
# Fixed 2% target and 0.5% stop, mirrored for bearish signals.
target_price = entry_price * (1.02 if signal_ai == "Bullish" else 0.98)
stop_loss = entry_price * (0.995 if signal_ai == "Bullish" else 1.005)
expected_points = (target_price - entry_price) if signal_ai == "Bullish" else (entry_price - target_price)

live_output = {
    "signal_dtw": signal_ai,
    "signal_ai": signal_ai,
    "probability": round(max(prob_bullish, prob_bearish), 2),
    "entry_price": round(float(entry_price), 2),
    "target_price": round(float(target_price), 2),
    "stop_loss": round(float(stop_loss), 2),
    "expected_points": round(float(expected_points), 2),
}
print(f"✅ Live Prediction: {live_output}")
Example Output:
✅ Live Prediction: {
'signal_dtw': 'Bullish',
'signal_ai': 'Bullish',
'probability': 0.95,
'entry_price': 110849.85,
'target_price': 113066.85,
'stop_loss': 110295.60,
'expected_points': 2217.00
}
5. Switching Between GPU and CPU
- Auto-detect:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- Force CPU:
device = torch.device("cpu")
- Force GPU:
device = torch.device("cuda")
Final Workflow
- Install requirements
pip install -r requirements-gpu.txt
- Prepare dataset
python btc_data.py
- Train model
python btc_1h_trend_mirror_trainV2.py
- Run live prediction
python TFT_predict.py
Summary
- btc_data.py → fetch + feature engineering.
- btc_1h_trend_mirror_trainV2.py → create dataset + train TFT.
- TFT_predict.py → run live predictions with probability, entry, target, stop.
- Works on CPU or GPU with one line change.
This is a production-ready pipeline for BTC trend prediction with TFT.