Slicing the Silicon: Eine hardwarezentrierte Tiefenanalyse der Tensor Parallelism
Das effektive Training von Large Language Models (LLMs) erfordert mehr als einfache Datenparallelität. Wenn die Gewichtsmatrizen eines Modells zu groß sind, um in den VRAM einer einzelnen GPU zu passen – oder wenn der Speicherbedarf für Zwischenaktivierungen explodiert – müssen wir das Modell selbst aufteilen.
Tensor Parallelism (TP) ist eine Technik der Modellparallelität innerhalb einer Schicht. Im Gegensatz zur Pipeline Parallelism (bei der das Modell vertikal nach Schichten aufgeteilt wird) oder Data Parallelism (bei der das Modell repliziert wird), teilt TP das Modell horizontal, indem einzelne Tensoren (Gewichtsmatrizen) auf mehrere Geräte verteilt werden. Dadurch kann ein Cluster von GPUs wie ein einziger, massiver Beschleuniger agieren und eine einzelne Operation gleichzeitig verarbeiten.
Dieser Artikel analysiert die hardwareseitigen Mechanismen von TP, stellt eine "bare metal"-Implementierung in PyTorch bereit und untersucht die entscheidenden Hardware-Abhängigkeiten.
1. Die Hardware-Mechanik: Die Matrix aufteilen
Auf Registerebene wird das Training neuronaler Netzwerke von Matrixmultiplikation (MatMul) dominiert: . Bei Tensor Parallelism nutzen wir die Eigenschaften der linearen Algebra, um genau diese Operation auf GPUs zu verteilen.
Es gibt zwei Hauptmethoden, diese Berechnung aufzuteilen, wobei jede spezifische Kommunikationsprimitive erfordert, um die Ergebnisse zu synchronisieren:
Strategie A: Spalten-lineare Parallelität
In diesem Schema teilen wir die Gewichtsmatrix entlang ihrer Spalten auf.
- Partitionierung: Wenn wir 2 GPUs haben, teilen wir in auf.
- AusfĂĽhrung: Wir replizieren die Eingabe auf beiden GPUs. GPU 1 berechnet und GPU 2 berechnet .
- Ergebnis: Jede GPU hält eine partielle Breite des Ausgaberesultats (z. B. die erste Hälfte der Ausgabe-Features).
- Kommunikation: Um die vollständige Ausgabe zu rekonstruieren, ist eine All-Gather-Operation erforderlich, um die Ergebnisse aller GPUs zu verketten.
Strategie B: Zeilen-lineare Parallelisierung
Hier teilen wir die Gewichtsmatrix entlang ihrer Zeilen auf.
- Partitionierung: wird in zwei Zeilenblöcke und (vertikal gestapelt) aufgeteilt.
- AusfĂĽhrung: Damit die Mathematik funktioniert, muss die Eingabe ebenfalls entlang ihrer letzten Dimension (Spalten) in aufgeteilt werden. GPU 1 berechnet .
- Ergebnis: Jede GPU hält eine Teilsumme des Endergebnisses. .
- Kommunikation: Um die gĂĽltige endgĂĽltige Ausgabe zu erhalten, mĂĽssen die Ergebnisse aller GPUs summiert werden. Dies erfordert eine All-Reduce-Operation.
Die "Megatron-LM"-Optimierung
Die Effizienz im TP ergibt sich aus der Kombination dieser beiden Strategien, um die Kommunikation zu minimieren. In einem Standard-Transformer-MLP-Block (Linear GeLU Linear) können wir die Aufteilungen so anordnen, dass eine Synchronisation in der Mitte vermieden wird.
- Layer 1 (Column Parallel): Die Gewichtsmatrix wird spaltenweise aufgeteilt. Die Ausgabe sind auf jeder GPU aufgeteilte Aktivierungen.
- Nichtlinearität (GeLU): Da GeLU eine elementweise Operation ist (), kann sie unabhängig auf die Teilergebnisse jeder GPU angewendet werden. Hier ist keine Kommunikation erforderlich.
- Layer 2 (Row Parallel): Die zweite Gewichtsmatrix wird zeilenweise aufgeteilt. Sie nimmt die aufgeteilte Ausgabe von Layer 1 direkt als aufgeteilten Input entgegen.
- Finale Synchronisation: Erst nach der zweiten Schicht fĂĽhren wir ein All-Reduce durch, um die Teilergebnisse zu summieren.
Durch diese Optimierung werden die Kommunikationsereignisse von zwei pro Block auf nur eines reduziert.
2. Bare Metal Implementierung: Reines PyTorch
Um genau zu verstehen, was auf der Hardware passiert, implementieren wir eine vereinfachte Row-Parallel Linear-Schicht unter Verwendung der grundlegenden torch.distributed-Primitiven. Dies umgeht höhere Abstraktionen und zeigt die Datenbewegung. Dieses Beispiel spiegelt die Logik wider, die in Community-Implementierungen zu finden ist.
import torch
import torch.nn as nn
import torch.distributed as dist
class RowParallelLinear(nn.Module):
def __init__(self, input_size, output_size):
super().__init__()
# 1. Setup World Info
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
# 2. Calculate Shard Size
# We split the input dimension (rows of the weight matrix) across GPUs
# Note: We assume input_size is divisible by world_size for simplicity
self.input_shard_size = input_size // self.world_size
# 3. Initialize Sharded Weights
# Each GPU only holds a fraction of the total weights!
# Shape: [output_size, input_shard_size]
# Note: PyTorch Linear weights are typically (out_features, in_features)
self.weight = nn.Parameter(torch.randn(output_size, self.input_shard_size))
# Bias is usually handled by one rank or replicated and reduced.
# Simplified here: no bias for clarity.
def forward(self, x):
# x input shape: [batch_size, input_size]
# 4. Scatter Input (Simulating the split)
# In a real Transformer, x might already be sharded from the previous
# Column-Parallel layer. Here we manually shard it to simulate the state.
# Split x along the last dimension (columns)
input_shards = list(x.chunk(self.world_size, dim=-1))
local_input = input_shards[self.rank].contiguous()
# 5. Local MatMul
# Compute the partial result for this GPU's slice of data
# local_output shape: [batch_size, output_size]
# We use .t() because PyTorch Linear weights are stored transposed
local_output = torch.matmul(local_input, self.weight.t())
# 6. All-Reduce (The Communication Bottleneck)
# Sum the partial results from all GPUs into every GPU
# This is a blocking operation! The GPU compute cores wait here.
dist.all_reduce(local_output, op=dist.ReduceOp.SUM)
return local_output
# Usage pseudo-code:
# dist.init_process_group(backend='nccl')
# layer = RowParallelLinear(1024, 512).cuda()
# y = layer(x)In diesem Code ist die Zeile dist.all_reduce der kritische Pfad. Die CUDA-Kerne sind untätig (oder "warten"), während der NCCL-Ring-Reduce-Algorithmus die Puffer zwischen den GPUs über das Interconnect weiterleitet.
3. Vor- und Nachteile: Eine Hardware-Perspektive
Vorteile
- Speicherreduktion: TP verteilt Modellparameter, Gradienten und Optimiererzustände auf (wobei die Anzahl der GPUs ist). Entscheidend ist, dass auch der Aktivierungsspeicher für die Matrixmultiplikationen verteilt wird, was den maximalen Speicherbedarf pro Gerät erheblich reduziert.
- Zugriff auf riesige Modelle: Es ermöglicht das Training von Modellen, bei denen die Gewichte einer einzelnen Schicht einfach zu groß sind, um in den VRAM einer einzelnen GPU zu passen.
- Reduzierte Latenz (im Vergleich zu Pipeline): Im Gegensatz zur Pipeline-Parallelisierung, die eine "Blase" von Leerlaufzeit einführt, während auf das Durchlaufen der Daten durch die Schichten gewartet wird, hält TP alle GPUs gleichzeitig aktiv (während der Rechenphase).
Nachteile
- Kommunikation auf dem kritischen Pfad: Dies ist der größte Nachteil. Bei TP findet die Kommunikation (All-Reduce) innerhalb der Vorwärts- und Rückwärtsdurchläufe jeder Schicht statt. Die GPU kann mit der nächsten Operation (wie LayerNorm) erst fortfahren, wenn der All-Reduce abgeschlossen ist. Dies unterbricht effektiv die Berechnung und verhindert die Überlappung von Berechnung und Kommunikation.
- Bandbreitenlimit: Aufgrund der hohen Synchronisationsfrequenz (zweimal pro Transformer-Schicht) ist TP stark von der Bandbreite der Verbindungen abhängig.
- Intra-Node (NVLink): TP funktioniert effizient innerhalb eines einzelnen Knotens (z. B. 8 GPUs), da NVLink eine enorme Bandbreite bereitstellt (z. B. 900 GB/s).
- Inter-Node (Ethernet/InfiniBand): Das Skalieren von TP über mehrere Knoten hinweg ist problematisch. Standard-Netzwerkgeschwindigkeiten sind um Größenordnungen langsamer als NVLink, wodurch die Kommunikationszeit die Berechnung dominiert. Benchmarks zeigen, dass der Durchsatz um etwa 43% sinkt, wenn TP von 8 auf 16 GPUs skaliert wird (Überschreiten der Knotengrenze).
- Implementierungskomplexität: Im Gegensatz zu FSDP, das Standard-PyTorch-Module umschließt, erfordert TP das Umschreiben des Modellierungscodes, um verteilte Gewichte zu handhaben und Kommunikationsprimitiven manuell einzufügen.
Ăśbersichtstabelle: Wann sollte TP verwendet werden?
| Scenario | Recommendation | Hardware Reason |
|---|---|---|
| Single Node (<= 8 GPUs) | Highly Recommended | NVLink bandwidth is sufficient to hide the synchronization cost. |
| Multi-Node (> 8 GPUs) | Avoid | Inter-node latency kills throughput. Use Data or Pipeline Parallelism instead. |
| Huge Weights | Required | If a layer doesn't fit in VRAM, TP is the only way to split the tensor itself. |
Die Rolle der 3D-Parallelisierung
Bei umfangreichen Trainingsläufen (wie Llama 3 oder GPT-4) ist Tensor Parallelism die "innerste" Schleife. Sie verwenden in der Regel Tensor Parallelism über die 8 GPUs innerhalb eines Knotens (um die massiven Gewichte unterzubringen) und kombinieren dies anschließend mit Pipeline Parallelism über verschiedene Knoten (um die Tiefe zu skalieren) sowie Data Parallelism (um die Batchgröße zu skalieren).