啟用 fp8 訓練後,“GPT-2 訓練時間”提升了 4.3%,現在只需 2.91 小時。另外值得一提的是,如果使用 8 倍 H100 實例價格,復現 GPT-2 的成本實際上只需約 20 美元。這令人振奮——
GPT-2(7 年前):發佈風險太大。
GPT-2(今天):新的 MNIST 數據集!:)
肯定能遠低於 1 小時。
關於 fp8,我再補充幾句。它比我預想的要複雜一些,我花了一段時間才最終決定採用它,即使現在,由於 fp8 的整體支持度較低,我仍然不能完全確定它是否是個好主意。理論上,H100 上的 fp8 浮點運算能力是 2 倍,但實際上卻遠低於此。在實際訓練過程中,我們並非完全受限於計算能力,額外的尺度轉換會帶來額外的開銷,GEMM 模型在 GPT-2 規模下還不夠大,不足以明顯抵消這些開銷,當然,精度越低,每一步的質量就越小。對於逐行縮放方案,FP8 和 BF16 的損失曲線非常接近,但網絡步進速度較慢。對於逐張縮放方案,損失曲線的差異更大(即每一步的質量都更差),但至少我們現在獲得了速度提升(約 7.3%)。你可以通過增加訓練週期(訓練更多步,但每一步速度更快)來簡單地恢復性能,並希望最終網絡性能能夠提升。在這種情況下,經過對這些方案和訓練週期的調整,目前我最終獲得了約 5% 的速度提升。 Torchao 在他們的論文中報告稱,Llama3-8B 的 FP8 訓練速度提升了 25%(相比之下,我未考慮模型容量的情況下提升了約 7.3%),這更接近我最初的預期,儘管 Llama3-8B 的模型規模要大得多。這可能並非 FP8 的終結。通過精確選擇應用 FP8 的層,並更謹慎地處理網絡中的數值,應該可以進一步提升性能。
twitter.com/karpathy/status/20...