Pytorch:單卡多程序並行訓練

2023-01-25 06:01:53

1 導引

我們在部落格《Python:多程序並行程式設計與程序池》中介紹瞭如何使用Python的multiprocessing模組進行並行程式設計。不過在深度學習的專案中,我們進行單機多程序程式設計時一般不直接使用multiprocessing模組,而是使用其替代品torch.multiprocessing模組。它支援完全相同的操作,但對其進行了擴充套件。

Python的multiprocessing模組可使用forkspawnforkserver三種方法來建立程序。但有一點需要注意的是,CUDA執行時不支援使用fork,我們可以使用spawnforkserver方法來建立子程序,以在子程序中使用CUDA。建立程序的方法可用multiprocessing.set_start_method(...) API來進行設定,比如下列程式碼就表示用spawn方法建立程序:

import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True) 

事實上,torch.multiprocessing在單機多程序程式設計中應用廣泛。尤其是在我們跑聯邦學習實驗時,常常需要在一張卡上並行訓練多個模型。注意,Pytorch多機分散式模組torch.distributed在單機上仍然需要手動fork程序。本文關注單卡多程序模型。

2 單卡多程序程式設計模型

我們在上一篇文章中提到過,多程序並行程式設計中最關鍵的一點就是程序間通訊。Python的multiprocessing採用共用記憶體進行程序間通訊。在我們的單卡多程序模型中,共用記憶體實際上可以直接由我們的CUDA記憶體擔任。

可能有讀者會表示不對啊,Pytorch中每個張量有一個tensor.share_memory_()用於將張量的資料移動到主機的共用記憶體中呀,如果CUDA記憶體直接擔任共用記憶體的作用,那要這個API幹啥呢?實際上,tensor.share_memory_()只在CPU模式下有使用的必要,如果張量分配在了CUDA上,這個函數實際上為空操作(no-op)。此外還需要注意,我們這裡的共用記憶體是程序間通訊的概念,注意與CUDA kernel層面的共用記憶體相區分。

注意,Python/Pytorch多程序模組的程序函數的引數和返回值必須相容於pickle編碼,任務的執行是在單獨的直譯器中完成的,進行程序間通訊時需要在不同的直譯器之間交換資料,此時必須要進行序列化處理。在機器學習中常使用的稀疏矩陣不能序列化,如果涉及稀疏矩陣的操作會發生異常: NotImplementedErrorCannot access storage of SparseTensorImpl,在多程序程式設計時需要轉換為稠密矩陣處理。

3 範例: 同步並行SGD演演算法

我們的範例採用在部落格《分散式機器學習:同步並行SGD演演算法的實現與複雜度分析(PySpark)》中所介紹的同步並行SGD演演算法。計算模式採用資料並行方式,即將資料進行劃分並分配到多個工作節點(Worker)上進行訓練。同步SGD演演算法的虛擬碼描述如下:

注意,我們此處的多程序共用記憶體,是無需劃分資料而各程序直接對共用記憶體進行非同步無鎖讀寫的(參考Hogwild!演演算法[3])。但是我們這裡為了演示同步並行SGD演演算法,還是為每個程序設定本地資料集和本地權重,且每個epoch各程序進行一次全域性同步,這樣也便於我們擴充套件到同步聯邦學習實驗環境。

在程式碼實現上,我們需要先對本地資料集進行劃,這裡需要繼承torch.utils.data.subset以自定義資料集類(參見我的部落格《Pytorch:自定義Subset/Dataset類完成資料集拆分 》):

class CustomSubset(Subset):
    '''A custom subset class with customizable data transformation'''
    def __init__(self, dataset, indices, subset_transform=None):
        super().__init__(dataset, indices)
        self.subset_transform = subset_transform

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        if self.subset_transform:
            x = self.subset_transform(x)
        return x, y   

    def __len__(self):
        return len(self.indices)

def dataset_split(dataset, n_workers):
    n_samples = len(dataset)
    n_sample_per_workers = n_samples // n_workers
    local_datasets = []
    for w_id in range(n_workers):
        if w_id < n_workers - 1:
            local_datasets.append(CustomSubset(dataset, range(w_id * n_sample_per_workers, (w_id + 1) * n_sample_per_workers)))
        else:
            local_datasets.append(CustomSubset(dataset, range(w_id * n_sample_per_workers, n_samples)))
    return local_datasets    

local_train_datasets = dataset_split(train_dataset, n_workers) 

然後定義本地模型、全域性模型和本地權重、全域性權重:

local_models = [Net().to(device) for i in range(n_workers)]
global_model = Net().to(device)
local_Ws = [{key: value for key, value in local_models[i].named_parameters()} for i in range(n_workers)]
global_W = {key: value for key, value in global_model.named_parameters()}

然後由於是同步演演算法,我們需要初始化多程序同步屏障:

from torch.multiprocessing import Barrier
synchronizer = Barrier(n_workers)

訓練演演算法流程(含測試部分)描述如下:

for epoch in range(epochs):
    for rank in range(n_workers):
        # pull down global model to local
        pull_down(global_W, local_Ws, n_workers)
        
        processes = []
        for rank in range(n_workers):
            p = mp.Process(target=train_epoch, args=(epoch, rank, local_models[rank], device,
                                            local_train_datasets[rank], synchronizer, kwargs))
            # We first train the model across `num_processes` processes
            p.start()
            processes.append(p)
                        
        for p in processes:
            p.join()
        
        test(global_model, device, test_dataset, kwargs)

        # init the global model
        init(global_W)
        aggregate(global_W, local_Ws, n_workers)

# Once training is complete, we can test the model
test(global_model, device, test_dataset, kwargs)

其中的pull_down()函數負責將全域性模型賦給本地模型:

def pull_down(global_W, local_Ws, n_workers):
    # pull down global model to local
    for rank in range(n_workers):
        for name, value in local_Ws[rank].items():
            local_Ws[rank][name].data = global_W[name].data 

init()函數負責給全域性模型進行初始化:

def init(global_W):
    # init the global model
    for name, value in global_W.items():
        global_W[name].data  = torch.zeros_like(value)

aggregate()函數負責對本地模型進行聚合(這裡我們採用最簡單的平均聚合方式):

def aggregate(global_W, local_Ws, n_workers):
    for rank in range(n_workers):
        for name, value in local_Ws[rank].items():
            global_W[name].data += value.data

    for name in local_Ws[rank].keys():
        global_W[name].data /= n_workers

最後,train_epochtest_epoch定義如下(注意train_epoch函數的結尾需要加上 synchronizer.wait()表示程序間同步):

def train_epoch(epoch, rank, local_model, device, dataset, synchronizer, dataloader_kwargs):
    torch.manual_seed(seed + rank)
    train_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)
    optimizer = optim.SGD(local_model.parameters(), lr=lr, momentum=momentum)

    local_model.train()
    pid = os.getpid()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = local_model(data.to(device))
        loss = F.nll_loss(output, target.to(device))
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                pid, epoch + 1, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
    synchronizer.wait()
    
    
def test(epoch, model, device, dataset, dataloader_kwargs):
    torch.manual_seed(seed)
    test_loader = torch.utils.data.DataLoader(dataset, **dataloader_kwargs)

    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.to(device))
            test_loss += F.nll_loss(output, target.to(device), reduction='sum').item() # sum up batch loss
            pred = output.max(1)[1] # get the index of the max log-probability
            correct += pred.eq(target.to(device)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest Epoch: {} Global loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        epoch + 1, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))   

我們在epochs=3n_workers=4的設定下執行結果如下圖所示(我們這裡僅展示每個epoch同步通訊後,使用測試集對全域性模型進行測試的結果):

Test Epoch: 1 Global loss: 0.0858, Accuracy: 9734/10000 (97%)
Test Epoch: 2 Global loss: 0.0723, Accuracy: 9794/10000 (98%)
Test Epoch: 3 Global loss: 0.0732, Accuracy: 9796/10000 (98%)

可以看到測試結果是趨於收斂的。
最後,完整程式碼我已經上傳到了GitHub倉庫 [Distributed-Algorithm-PySpark]
,感興趣的童鞋可以前往檢視。

參考