{"id":178033,"date":"2026-01-20T14:37:50","date_gmt":"2026-01-20T14:37:50","guid":{"rendered":"https:\/\/ktromedia.com\/?p=178033"},"modified":"2026-01-20T14:37:50","modified_gmt":"2026-01-20T14:37:50","slug":"training-a-model-with-limited-memory-using-mixed-precision-and-gradient-checkpointing","status":"publish","type":"post","link":"http:\/\/ktromedia.com\/?p=178033","title":{"rendered":"Training a Model with Limited Memory using Mixed Precision and Gradient Checkpointing"},"content":{"rendered":"<div id=\"\">\n<p>Training a language model is memory-intensive, not only because the model itself is large but also because training data batches often contain long sequences. Training a model with limited memory is challenging. In this article, you will learn techniques that enable model training in memory-constrained environments. In particular, you will learn about:<\/p>\n<ul>\n<li>Low-precision floating-point numbers and mixed-precision training<\/li>\n<li>Using gradient checkpointing<\/li>\n<\/ul>\n<p>Let\u2019s get started!<\/p>\n<div id=\"attachment_22894\" style=\"width: 810px\" class=\"wp-caption aligncenter\"><\/p>\n<p id=\"caption-attachment-22894\" class=\"wp-caption-text\">Training a Model with Limited Memory using Mixed Precision and Gradient Checkpointing<br \/>Photo by <a href=\"https:\/\/unsplash.com\/photos\/river-between-brown-concrete-buildings-PdnseHuDFZU\">Meduana<\/a>. Some rights reserved.<\/p>\n<\/div>\n<h2>Overview<\/h2>\n<p>This article is divided into three parts; they are:<\/p>\n<ul>\n<li>Floating-point Numbers<\/li>\n<li>Automatic Mixed Precision Training<\/li>\n<li>Gradient Checkpointing<\/li>\n<\/ul>\n<p>Let\u2019s get started!<\/p>\n<h2>Floating-Point Numbers<\/h2>\n<p>The default data type in PyTorch is the IEEE 754 32-bit floating-point format, also known as single precision. It is not the only floating-point type you can use. 
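<\/p>
<p>As a quick sanity check, any tensor created without an explicit <code>dtype<\/code> uses this global default:<\/p>

```python
import torch

# tensors created without an explicit dtype use the global default
print(torch.get_default_dtype())       # torch.float32
print(torch.tensor([1.0, 2.0]).dtype)  # torch.float32
```

<p>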
For example, most CPUs support 64-bit double-precision floating-point, and GPUs often support half-precision floating-point as well. The table below lists some floating-point types:<\/p>\n<div style=\"overflow: scroll; white-space: nowrap;\">\n<table>\n<thead>\n<tr>\n<th>Data Type<\/th>\n<th>PyTorch Type<\/th>\n<th>Total Bits<\/th>\n<th>Sign Bit<\/th>\n<th>Exponent Bits<\/th>\n<th>Mantissa Bits<\/th>\n<th>Min Value<\/th>\n<th>Max Value<\/th>\n<th>eps<\/th>\n<\/tr>\n<\/thead>\n<tbody>\n<tr>\n<td>IEEE 754 double precision<\/td>\n<td><code>torch.float64<\/code><\/td>\n<td>64<\/td>\n<td>1<\/td>\n<td>11<\/td>\n<td>52<\/td>\n<td>-1.79769e+308<\/td>\n<td>1.79769e+308<\/td>\n<td>2.22045e-16<\/td>\n<\/tr>\n<tr>\n<td>IEEE 754 single precision<\/td>\n<td><code>torch.float32<\/code><\/td>\n<td>32<\/td>\n<td>1<\/td>\n<td>8<\/td>\n<td>23<\/td>\n<td>-3.40282e+38<\/td>\n<td>3.40282e+38<\/td>\n<td>1.19209e-07<\/td>\n<\/tr>\n<tr>\n<td>IEEE 754 half precision<\/td>\n<td><code>torch.float16<\/code><\/td>\n<td>16<\/td>\n<td>1<\/td>\n<td>5<\/td>\n<td>10<\/td>\n<td>-65504<\/td>\n<td>65504<\/td>\n<td>0.000976562<\/td>\n<\/tr>\n<tr>\n<td>bf16<\/td>\n<td><code>torch.bfloat16<\/code><\/td>\n<td>16<\/td>\n<td>1<\/td>\n<td>8<\/td>\n<td>7<\/td>\n<td>-3.38953e+38<\/td>\n<td>3.38953e+38<\/td>\n<td>0.0078125<\/td>\n<\/tr>\n<tr>\n<td>fp8 (e4m3)<\/td>\n<td><code>torch.float8_e4m3fn<\/code><\/td>\n<td>8<\/td>\n<td>1<\/td>\n<td>4<\/td>\n<td>3<\/td>\n<td>-448<\/td>\n<td>448<\/td>\n<td>0.125<\/td>\n<\/tr>\n<tr>\n<td>fp8 (e5m2)<\/td>\n<td><code>torch.float8_e5m2<\/code><\/td>\n<td>8<\/td>\n<td>1<\/td>\n<td>5<\/td>\n<td>2<\/td>\n<td>-57344<\/td>\n<td>57344<\/td>\n<td>0.25<\/td>\n<\/tr>\n<tr>\n<td>fp8 (e8m0, unsigned)<\/td>\n<td><code>torch.float8_e8m0fnu<\/code><\/td>\n<td>8<\/td>\n<td>0<\/td>\n<td>8<\/td>\n<td>0<\/td>\n<td>5.87747e-39<\/td>\n<td>1.70141e+38<\/td>\n<td>1.0<\/td>\n<\/tr>\n<tr>\n<td>fp6 
(e3m2)<\/td>\n<td\/>\n<td>6<\/td>\n<td>1<\/td>\n<td>3<\/td>\n<td>2<\/td>\n<td>-28<\/td>\n<td>28<\/td>\n<td>0.25<\/td>\n<\/tr>\n<tr>\n<td>fp6 (e2m3)<\/td>\n<td\/>\n<td>6<\/td>\n<td>1<\/td>\n<td>2<\/td>\n<td>3<\/td>\n<td>-7.5<\/td>\n<td>7.5<\/td>\n<td>0.125<\/td>\n<\/tr>\n<tr>\n<td>fp4 (e2m1)<\/td>\n<td\/>\n<td>4<\/td>\n<td>1<\/td>\n<td>2<\/td>\n<td>1<\/td>\n<td>-6<\/td>\n<td>6<\/td>\n<td\/>\n<\/tr>\n<\/tbody>\n<\/table>\n<\/div>\n<p>Floating-point numbers are binary representations of real numbers. Each consists of a sign bit, several bits for the exponent, and several bits for the mantissa. They are laid out as shown in the figure below. When sorted by their binary representation, floating-point numbers retain their order by real-number value.<\/p>\n<div id=\"attachment_22893\" style=\"width: 760px\" class=\"wp-caption aligncenter\"><img fetchpriority=\"high\" decoding=\"async\" aria-describedby=\"caption-attachment-22893\" class=\"wp-image-22893\" src=\"https:\/\/ktromedia.com\/wp-content\/uploads\/2026\/01\/Training-a-Model-with-Limited-Memory-using-Mixed-Precision-and.png\" alt=\"\" width=\"750\" height=\"96\"  \/><\/p>\n<p id=\"caption-attachment-22893\" class=\"wp-caption-text\">Floating-point number representation. Figure from <a href=\"https:\/\/commons.wikimedia.org\/wiki\/File:Float_example.svg\">Wikimedia<\/a>.<\/p>\n<\/div>\n<p>Different floating-point types have different ranges and precisions. Not all types are supported by all hardware. For example, fp4 is only supported in Nvidia\u2019s Blackwell architecture. PyTorch supports only a few data types. 
You can run the following code to print information about various floating-point types:<\/p>\n<pre><code>import torch\nfrom tabulate import tabulate\n\n# float types to inspect\nfloat_types = [\n    torch.float64,\n    torch.float32,\n    torch.float16,\n    torch.bfloat16,\n    torch.float8_e4m3fn,\n    torch.float8_e5m2,\n    torch.float8_e8m0fnu,\n]\n\n# collect finfo for each type\ntable = []\nfor dtype in float_types:\n    info = torch.finfo(dtype)\n    try:\n        typename = info.dtype\n    except AttributeError:\n        typename = str(dtype)\n    table.append([typename, info.max, info.min, info.smallest_normal, info.eps])\n\nheaders = ['data type', 'max', 'min', 'smallest normal', 'eps']\nprint(tabulate(table, headers=headers))\n<\/code><\/pre>\n<p>Pay attention to the min, max, and eps values of each type. The min and max values indicate the range the type can represent (its <strong>dynamic range<\/strong>). If you train a model in such a type and the weights or activations exceed this range, you will get overflow or underflow, usually causing the model to output NaN or Inf. The eps value is the smallest positive number such that the type can distinguish <code>1+eps<\/code> from <code>1<\/code>; it is a measure of precision. If your model\u2019s gradient updates are smaller than eps, you will likely observe the vanishing gradient problem.<\/p>\n<p>Therefore, <code>float32<\/code> is a good default choice for deep learning: it has a wide dynamic range and high precision. However, each <code>float32<\/code> number requires 4 bytes of memory. 
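<\/p>
<p>A small experiment makes both failure modes concrete. The following sketch shows overflow past the float32 max value, and a step smaller than eps being absorbed by rounding:<\/p>

```python
import torch

# overflow: doubling a value near float32's max (~3.40282e+38) gives inf
x = torch.tensor(3e38)
print(x * 2)  # tensor(inf)

# precision: adding less than eps to 1 is lost to rounding
eps = torch.finfo(torch.float32).eps  # about 1.19209e-07
one = torch.tensor(1.0)
print(one + eps / 2 == one)  # tensor(True)
```

<p>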
As a compromise, you can use <code>float16<\/code> to save memory, but you are more likely to encounter overflow or underflow issues since the dynamic range is much smaller.<\/p>\n<p>The Google Brain team identified this problem and proposed <code>bfloat16<\/code>, a 16-bit floating-point format with the same dynamic range as <code>float32<\/code>. As a trade-off, its precision is roughly an order of magnitude worse than that of <code>float16<\/code>. It turns out that dynamic range matters more than precision for deep learning, which makes <code>bfloat16<\/code> highly useful.<\/p>\n<p>When you create a tensor in PyTorch, you can specify the data type. For example:<\/p>\n<pre><code>x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)\nprint(x)\n<\/code><\/pre>\n<p>There is a straightforward way to change the default to a different type, such as <code>bfloat16<\/code>. This is handy for model training. 
All you need to do is set the following line before you create any model or optimizer:<\/p>\n<pre><code># set default dtype to bfloat16\ntorch.set_default_dtype(torch.bfloat16)\n<\/code><\/pre>\n<p>Just by doing this, you force all your model weights and gradients to be <code>bfloat16<\/code> type. This saves half of the memory. 
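<\/p>
<p>You can verify the saving directly. A minimal sketch comparing per-element storage, and showing that <code>bfloat16<\/code> retains the <code>float32<\/code> range where <code>float16<\/code> overflows:<\/p>

```python
import torch

# bytes per element: bfloat16 uses half the memory of float32
print(torch.zeros(10, dtype=torch.float32).element_size())   # 4
print(torch.zeros(10, dtype=torch.bfloat16).element_size())  # 2

# unlike float16, bfloat16 keeps float32's dynamic range
print(torch.tensor(1e38, dtype=torch.float16))   # inf (overflow)
print(torch.tensor(1e38, dtype=torch.bfloat16))  # a finite value near 1e38
```

<p>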
In the <a href=\"https:\/\/machinelearningmastery.com\/pretraining-a-llama-model-on-your-local-gpu\/\">previous article<\/a>, you were advised to set the batch size to 8 to fit a GPU with only 12GB of VRAM. With <code>bfloat16<\/code>, you should be able to set the batch size to 16.<\/p>\n<p>Note that using 8-bit float or lower-precision types may not work out of the box: both your hardware and PyTorch must support the corresponding mathematical operations. You can try the following code (requires a CUDA device) and see that extra effort is needed to operate on 8-bit floats:<\/p>\n<pre><code>dtype = torch.float8_e4m3fn\n\n# Defining a tensor directly in float8 fails with\n# NotImplementedError: &quot;normal_kernel_cuda&quot; not implemented for 'Float8_e4m3fn'\nx = torch.randn(16, 16, dtype=dtype, device=&quot;cuda&quot;)\n\n# Creating the tensor in float32 and converting it to float8 works\nx = torch.randn(16, 16, device=&quot;cuda&quot;).to(dtype)\n\n# But matmul is not supported. You will see\n# NotImplementedError: &quot;addmm_cuda&quot; not implemented for 'Float8_e4m3fn'\ny = x @ x.T\n\n# The correct way to run matrix multiplication on 8-bit float\ny = torch._scaled_mm(x, x.T, out_dtype=dtype,\n                     scale_a=torch.tensor(1.0, device=&quot;cuda&quot;),\n                     scale_b=torch.tensor(1.0, device=&quot;cuda&quot;))\nprint(y)\n<\/code><\/pre>\n<h2>Automatic Mixed Precision Training<\/h2>\n<p>Training a model with float16 may run into trouble because not all operations should be performed at lower precision. For example, matrix multiplication is robust at lower precision, but reduction operations, pooling, and some activation functions require float32.<\/p>\n<p>You can set the data type manually for each component of your model, but this is tedious because you must convert data types between components. A better solution is to use <strong>automatic mixed precision<\/strong> training in PyTorch.<\/p>\n<p>PyTorch has a sub-library, <code>torch.amp<\/code>, that automatically casts the data type based on the operation. Not all operations are carried out in the same floating-point type: if an operation is known to be robust at lower precision, the library will <strong>cast<\/strong> the tensors to that precision before running it, hence the name \u201cmixed precision\u201d. Using lower precision may not only save memory but also speed up training. 
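<\/p>
<p>You can see the casting in action without a GPU. A minimal sketch using CPU autocast with <code>bfloat16<\/code> (which CPU autocast supports): the matrix multiplication is performed in <code>bfloat16<\/code>, while the inputs themselves stay <code>float32<\/code>:<\/p>

```python
import torch

a = torch.randn(4, 4)  # float32 by default
b = torch.randn(4, 4)
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    c = a @ b  # autocast casts the inputs, so the result is bfloat16

print(a.dtype)  # torch.float32
print(c.dtype)  # torch.bfloat16
```

<p>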
Some GPUs can run float16 operations at twice the speed of float32.<\/p>\n<p>When you train a model with <code>torch.amp<\/code>, all you need to do is run your forward pass under the context of <code>torch.amp.autocast()<\/code>. Typically, you will also use a <code>GradScaler<\/code> to handle gradient scaling. This is necessary because at low precision, small gradient values can underflow to zero and vanish. The <code>GradScaler<\/code> multiplies the loss by a scale factor before the backward pass so that small gradients survive in the low-precision format, then unscales the gradients before the optimizer step so the weight updates remain accurate. Doing this by hand is cumbersome because you need to determine the correct scale factor; the <code>GradScaler<\/code> handles it for you. Note that gradient scaling mainly matters for <code>float16<\/code>; with <code>bfloat16<\/code>, whose dynamic range matches <code>float32<\/code>, it is usually unnecessary but harmless.<\/p>\n<p>Compared to the training loop from the <a href=\"https:\/\/machinelearningmastery.com\/pretraining-a-llama-model-on-your-local-gpu\/\">previous article<\/a>, below is how you typically use <code>torch.amp<\/code> to train a model:<\/p>\n<div id=\"urvanov-syntax-highlighter-69646924ccd1f959350087\" class=\"urvanov-syntax-highlighter-syntax crayon-theme-classic urvanov-syntax-highlighter-font-monaco urvanov-syntax-highlighter-os-pc print-yes notranslate\" data-settings=\" minimize scroll-mouseover disable-anim\" style=\" margin-top: 12px; margin-bottom: 12px; font-size: 12px !important; line-height: 15px !important;\">\n<p><textarea wrap=\"soft\" class=\"urvanov-syntax-highlighter-plain print-no\" data-settings=\"dblclick\" readonly=\"readonly\" style=\"-moz-tab-size:4; -o-tab-size:4; -webkit-tab-size:4; tab-size:4; font-size: 12px !important; line-height: 15px !important;\"><br \/>\n&#8230;&#13;<br \/>\n&#13;<br \/>\n# Check if mixed precision training is supported&#13;<br \/>\nassert torch.amp.autocast_mode.is_autocast_available(&quot;cuda&quot;)&#13;<br \/>\n&#13;<br \/>\n# Creates a GradScaler before the training loop&#13;<br \/>\nscaler = torch.amp.GradScaler(&quot;cuda&quot;, 
enabled=True)&#13;<br \/>\n&#13;<br \/>\n# start training&#13;<br \/>\nfor epoch in range(begin_epoch, epochs):&#13;<br \/>\n    pbar = tqdm.tqdm(dataloader, desc=f&quot;Epoch {epoch+1}\/{epochs}&quot;)&#13;<br \/>\n    for batch_id, batch in enumerate(pbar):&#13;<br \/>\n        # get batched data&#13;<br \/>\n        input_ids, target_ids = batch&#13;<br \/>\n        # create attention mask: causal mask + padding mask&#13;<br \/>\n        attn_mask = create_causal_mask(input_ids.shape[1], device) + \\&#13;<br \/>\n                    create_padding_mask(input_ids, PAD_TOKEN_ID, device)&#13;<br \/>\n        # with autocasting to bfloat16, run the forward pass&#13;<br \/>\n        with torch.autocast(device_type=&quot;cuda&quot;, dtype=torch.bfloat16):&#13;<br \/>\n            logits = model(input_ids, attn_mask)&#13;<br \/>\n            loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))&#13;<br \/>\n        # backward with loss, scaled by the GradScaler&#13;<br \/>\n        optimizer.zero_grad()&#13;<br \/>\n        scaler.scale(loss).backward()&#13;<br \/>\n        # step the optimizer and check if the scale has been updated&#13;<br \/>\n        scaler.step(optimizer)&#13;<br \/>\n        old_scale = scaler.get_scale()&#13;<br \/>\n        scaler.update()&#13;<br \/>\n        if scaler.get_scale() &lt; old_scale:<\/textarea><\/p>\n<div class=\"urvanov-syntax-highlighter-main\" style=\"\">\n<table class=\"crayon-table\">\n<tr class=\"urvanov-syntax-highlighter-row\">\n<td class=\"crayon-nums \" data-settings=\"show\">\n<div class=\"urvanov-syntax-highlighter-nums-content\" style=\"font-size: 12px !important; line-height: 15px 
!important;\">\n<p>1<\/p>\n<p>2<\/p>\n<p>3<\/p>\n<p>4<\/p>\n<p>5<\/p>\n<p>6<\/p>\n<p>7<\/p>\n<p>8<\/p>\n<p>9<\/p>\n<p>10<\/p>\n<p>11<\/p>\n<p>12<\/p>\n<p>13<\/p>\n<p>14<\/p>\n<p>15<\/p>\n<p>16<\/p>\n<p>17<\/p>\n<p>18<\/p>\n<p>19<\/p>\n<p>20<\/p>\n<p>21<\/p>\n<p>22<\/p>\n<p>23<\/p>\n<p>24<\/p>\n<p>25<\/p>\n<p>26<\/p>\n<p>27<\/p>\n<p>28<\/p>\n<p>29<\/p>\n<p>30<\/p>\n<p>31<\/p>\n<p>32<\/p>\n<p>33<\/p>\n<\/div>\n<\/td>\n<td class=\"urvanov-syntax-highlighter-code\">\n<div class=\"crayon-pre\" style=\"font-size: 12px !important; line-height: 15px !important; -moz-tab-size:4; -o-tab-size:4; -webkit-tab-size:4; tab-size:4;\">\n<p><span class=\"crayon-sy\">.<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-sy\">.<\/span><\/p>\n<p>\u00a0<\/p>\n<p><span class=\"crayon-p\"># Check if mixed precision training is supported<\/span><\/p>\n<p><span class=\"crayon-st\">assert<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">torch<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-v\">amp<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-v\">autocast_mode<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">is_autocast_available<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-s\">&#8220;cuda&#8221;<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p>\u00a0<\/p>\n<p><span class=\"crayon-p\"># Creates a GradScaler before the training loop<\/span><\/p>\n<p><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">torch<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-v\">amp<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">GradScaler<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-s\">&#8220;cuda&#8221;<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">enabled<\/span><span 
class=\"crayon-o\">=<\/span><span class=\"crayon-t\">True<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p>\u00a0<\/p>\n<p><span class=\"crayon-p\"># start training<\/span><\/p>\n<p><span class=\"crayon-st\">for<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-e\">epoch <\/span><span class=\"crayon-st\">in<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-e\">range<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">begin_epoch<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">epochs<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-o\">:<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">pbar<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">tqdm<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">tqdm<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">dataloader<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">desc<\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-i\">f<\/span><span class=\"crayon-s\">&#8220;Epoch {epoch+1}\/{epochs}&#8221;<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-st\">for<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">batch_id<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-e\">batch <\/span><span class=\"crayon-st\">in<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-e\">enumerate<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">pbar<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-o\">:<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-p\"># get 
batched data<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">input_ids<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">target_ids<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-i\">batch<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-p\"># create attention mask: causal mask + padding mask<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">attn_mask<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-e\">create_causal_mask<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">input_ids<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-v\">shape<\/span><span class=\"crayon-sy\">[<\/span><span class=\"crayon-cn\">1<\/span><span class=\"crayon-sy\">]<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">device<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">+<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-sy\">\\<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-e\">create_padding_mask<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">input_ids<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">PAD_TOKEN_ID<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">device<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span 
class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-p\"># with autocasting to bfloat16, run the forward pass<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-e\">with <\/span><span class=\"crayon-v\">torch<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">autocast<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">device_type<\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-s\">&#8220;cuda&#8221;<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">dtype<\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-v\">torch<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-v\">bfloat16<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-o\">:<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">logits<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-e\">model<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">input_ids<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">attn_mask<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">loss<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-e\">loss_fn<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">logits<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">view<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-o\">&#8211;<\/span><span class=\"crayon-cn\">1<\/span><span class=\"crayon-sy\">,<\/span><span 
class=\"crayon-h\"> <\/span><span class=\"crayon-v\">logits<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">size<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-o\">&#8211;<\/span><span class=\"crayon-cn\">1<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">target_ids<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">view<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-o\">&#8211;<\/span><span class=\"crayon-cn\">1<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-p\"># backward with loss, scaled by the GradScaler<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">optimizer<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">zero_grad<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">scale<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">loss<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">backward<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-p\"># step the optimizer and check if the scale has been updated<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">step<\/span><span 
class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">optimizer<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">old_scale<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">get_scale<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">update<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-st\">if<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">get_scale<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\"><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">old_scale<\/span><span class=\"crayon-o\">:<\/span><\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">scheduler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">step<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">pbar<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">set_postfix<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">loss<\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-v\">loss<\/span><span 
class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">item<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">pbar<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">update<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-cn\">1<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">pbar<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">close<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<\/div>\n<\/td>\n<\/tr>\n<\/table><\/div>\n<\/p><\/div>\n<p>Using AMP autocasting is straightforward: keep the model\u2019s default precision at float32, then wrap the forward pass and loss computation with <code>torch.autocast()<\/code>. Under this context, all supported operations will run in the specified data type.<\/p>\n<p>Once you have the loss, let the <code>GradScaler<\/code> handle the backward pass. It will scale up the loss and update the model\u2019s gradients. However, this may cause issues if the scaling is too large, resulting in NaN or Inf gradients. Therefore, use <code>scaler.step(optimizer)<\/code> to step the optimizer, which verifies the gradients before executing the optimizer step. If <code>GradScaler<\/code> decides not to step the optimizer, it will reduce the scale factor when <code>update()<\/code> is called. Check whether the scale has been updated to determine if you should step the scheduler.<\/p>\n<p>Since the backward pass uses scaled loss, if you use gradient clipping, you should <strong>unscale<\/strong>\u00a0the gradients before clipping. 
Here\u2019s how to do it:<\/p>\n<div id=\"urvanov-syntax-highlighter-69646924ccd27241482935\" class=\"urvanov-syntax-highlighter-syntax crayon-theme-classic urvanov-syntax-highlighter-font-monaco urvanov-syntax-highlighter-os-pc print-yes notranslate\" data-settings=\" minimize scroll-mouseover disable-anim\" style=\" margin-top: 12px; margin-bottom: 12px; font-size: 12px !important; line-height: 15px !important;\">\n<p><textarea wrap=\"soft\" class=\"urvanov-syntax-highlighter-plain print-no\" data-settings=\"dblclick\" readonly=\"readonly\" style=\"-moz-tab-size:4; -o-tab-size:4; -webkit-tab-size:4; tab-size:4; font-size: 12px !important; line-height: 15px !important;\"><br \/>\n&#8230;&#13;<br \/>\n# backward with loss, scaled by the GradScaler&#13;<br \/>\noptimizer.zero_grad()&#13;<br \/>\nscaler.scale(loss).backward()&#13;<br \/>\n# unscale the gradients and apply gradient clipping&#13;<br \/>\nscaler.unscale_(optimizer)&#13;<br \/>\ntorch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)&#13;<br \/>\n# step the optimizer and check if the scale has been updated&#13;<br \/>\nscaler.step(optimizer)&#13;<br \/>\nold_scale = scaler.get_scale()&#13;<br \/>\nscaler.update()&#13;<br \/>\nif scaler.get_scale() &gt;= old_scale:&#13;<br \/>\n    scheduler.step()<\/textarea><\/p>\n<div class=\"urvanov-syntax-highlighter-main\" style=\"\">\n<table class=\"crayon-table\">\n<tr class=\"urvanov-syntax-highlighter-row\">\n<td class=\"crayon-nums \" data-settings=\"show\">\n<\/td>\n<td class=\"urvanov-syntax-highlighter-code\">\n<div class=\"crayon-pre\" style=\"font-size: 12px !important; line-height: 15px !important; -moz-tab-size:4; -o-tab-size:4; -webkit-tab-size:4; tab-size:4;\">\n<p><span class=\"crayon-sy\">.<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-sy\">.<\/span><\/p>\n<p><span class=\"crayon-p\"># backward with loss, scaled by the GradScaler<\/span><\/p>\n<p><span class=\"crayon-v\">optimizer<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">zero_grad<\/span><span 
class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">scale<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">loss<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">backward<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-p\"># unscaled the gradients and apply gradient clipping<\/span><\/p>\n<p><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">unscale_<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">optimizer<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-v\">torch<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-v\">nn<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-v\">utils<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">clip_grad_norm_<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">model<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">parameters<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-cn\">1.0<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-p\"># step the optimizer and check if the scale has been updated<\/span><\/p>\n<p><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">step<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">optimizer<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-v\">old_scale<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span 
class=\"crayon-e\">get_scale<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">update<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-st\">if<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">get_scale<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\"><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">old_scale<\/span><span class=\"crayon-o\">:<\/span><\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-v\">scheduler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">step<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<\/div>\n<\/td>\n<\/tr>\n<\/table><\/div>\n<\/p><\/div>\n<p>Normally, you don\u2019t need to call <code>scaler.unscale_()<\/code> manually since it\u2019s part of the <code>scaler.step(optimizer)<\/code> call. However, you must do so when applying gradient clipping so that the clipping function can observe the actual gradients.<\/p>\n<p>Autocasting is automatic, but the <code>GradScaler<\/code> maintains a state to track the scale factor. 
Therefore, when you checkpoint your model, you should also save the <code>scaler.state_dict()<\/code>, just as you would save the optimizer state:<\/p>\n<div id=\"urvanov-syntax-highlighter-69646924ccd2c296539992\" class=\"urvanov-syntax-highlighter-syntax crayon-theme-classic urvanov-syntax-highlighter-font-monaco urvanov-syntax-highlighter-os-pc print-yes notranslate\" data-settings=\" minimize scroll-mouseover disable-anim\" style=\" margin-top: 12px; margin-bottom: 12px; font-size: 12px !important; line-height: 15px !important;\">\n<p><textarea wrap=\"soft\" class=\"urvanov-syntax-highlighter-plain print-no\" data-settings=\"dblclick\" readonly=\"readonly\" style=\"-moz-tab-size:4; -o-tab-size:4; -webkit-tab-size:4; tab-size:4; font-size: 12px !important; line-height: 15px !important;\"><br \/>\n&#8230;&#13;<br \/>\n# Loading checkpoint&#13;<br \/>\ncheckpoint = torch.load(&#8220;training_checkpoint.pth&#8221;)&#13;<br \/>\nmodel.load_state_dict(checkpoint[&#8220;model&#8221;])&#13;<br \/>\noptimizer.load_state_dict(checkpoint[&#8220;optimizer&#8221;])&#13;<br \/>\nscheduler.load_state_dict(checkpoint[&#8220;scheduler&#8221;])&#13;<br \/>\nscaler.load_state_dict(checkpoint[&#8220;scaler&#8221;])&#13;<br \/>\n&#13;<br \/>\n# Saving checkpoint&#13;<br \/>\ntorch.save({&#13;<br \/>\n    &#8220;model&#8221;: model.state_dict(),&#13;<br \/>\n    &#8220;optimizer&#8221;: optimizer.state_dict(),&#13;<br \/>\n    &#8220;scheduler&#8221;: scheduler.state_dict(),&#13;<br \/>\n    &#8220;scaler&#8221;: scaler.state_dict(),&#13;<br \/>\n}, f&#8221;training_checkpoint.pth&#8221;)<\/textarea><\/p>\n<div class=\"urvanov-syntax-highlighter-main\" style=\"\">\n<table class=\"crayon-table\">\n<tr class=\"urvanov-syntax-highlighter-row\">\n<td class=\"crayon-nums \" data-settings=\"show\">\n<\/td>\n<td class=\"urvanov-syntax-highlighter-code\">\n<div class=\"crayon-pre\" style=\"font-size: 12px !important; line-height: 15px !important; -moz-tab-size:4; -o-tab-size:4; 
-webkit-tab-size:4; tab-size:4;\">\n<p><span class=\"crayon-sy\">.<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-sy\">.<\/span><\/p>\n<p><span class=\"crayon-p\"># Loading checkpoint<\/span><\/p>\n<p><span class=\"crayon-v\">checkpoint<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-o\">=<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">torch<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">load<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-s\">&#8220;training_checkpoint.pth&#8221;<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-v\">model<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">load_state_dict<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">checkpoint<\/span><span class=\"crayon-sy\">[<\/span><span class=\"crayon-s\">&#8220;model&#8221;<\/span><span class=\"crayon-sy\">]<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-v\">optimizer<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">load_state_dict<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">checkpoint<\/span><span class=\"crayon-sy\">[<\/span><span class=\"crayon-s\">&#8220;optimizer&#8221;<\/span><span class=\"crayon-sy\">]<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-v\">scheduler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">load_state_dict<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">checkpoint<\/span><span class=\"crayon-sy\">[<\/span><span class=\"crayon-s\">&#8220;scheduler&#8221;<\/span><span class=\"crayon-sy\">]<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">load_state_dict<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-v\">checkpoint<\/span><span class=\"crayon-sy\">[<\/span><span 
class=\"crayon-s\">&#8220;scaler&#8221;<\/span><span class=\"crayon-sy\">]<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<p>\u00a0<\/p>\n<p><span class=\"crayon-p\"># Saving checkpoint<\/span><\/p>\n<p><span class=\"crayon-v\">torch<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">save<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">{<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-s\">&#8220;model&#8221;<\/span><span class=\"crayon-o\">:<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">model<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">state_dict<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">,<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-s\">&#8220;optimizer&#8221;<\/span><span class=\"crayon-o\">:<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">optimizer<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">state_dict<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">,<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-s\">&#8220;scheduler&#8221;<\/span><span class=\"crayon-o\">:<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">scheduler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">state_dict<\/span><span class=\"crayon-sy\">(<\/span><span class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">,<\/span><\/p>\n<p><span class=\"crayon-h\">\u00a0\u00a0\u00a0\u00a0<\/span><span class=\"crayon-s\">&#8220;scaler&#8221;<\/span><span class=\"crayon-o\">:<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-v\">scaler<\/span><span class=\"crayon-sy\">.<\/span><span class=\"crayon-e\">state_dict<\/span><span class=\"crayon-sy\">(<\/span><span 
class=\"crayon-sy\">)<\/span><span class=\"crayon-sy\">,<\/span><\/p>\n<p><span class=\"crayon-sy\">}<\/span><span class=\"crayon-sy\">,<\/span><span class=\"crayon-h\"> <\/span><span class=\"crayon-i\">f<\/span><span class=\"crayon-s\">&#8220;training_checkpoint.pth&#8221;<\/span><span class=\"crayon-sy\">)<\/span><\/p>\n<\/div>\n<\/td>\n<\/tr>\n<\/table><\/div>\n<\/p><\/div>\n<h2>Gradient Checkpointing<\/h2>\n<p>When you train a model with half precision, you use half the memory compared to 32-bit float. With mixed-precision training, you may use slightly more memory because not all operations run at lower precision.<\/p>\n<p>If you still encounter memory issues, another technique that trades time for memory is <strong>gradient checkpointing<\/strong>. Recall that in deep learning, for a function $y=f(\\mathbb{u})$ and $\\mathbb{u}=g(\\mathbb{x}))$, then<\/p>\n<p>$$<br \/>\\frac{\\partial y}{\\partial \\mathbb{x}} = \\big(\\frac{\\partial \\mathbb{u}}{\\partial \\mathbb{x}}\\big)^\\top \\frac{\\partial y}{\\partial \\mathbb{u}}<br \/>$$<\/p>\n<p>where $y$ is a scalar (usually the loss metric), and $\\mathbb{u}$ and $\\mathbb{x}$ are vectors. The term $\\frac{\\partial \\mathbb{u}}{\\partial \\mathbb{x}}$ is the Jacobian matrix of $\\mathbb{u}$ with respect to $\\mathbb{x}$.<\/p>\n<p>The gradient $\\frac{\\partial y}{\\partial \\mathbb{x}}$ is needed to update $\\mathbb{x}$ but depends on $\\frac{\\partial y}{\\partial \\mathbb{u}}$. Normally, when you run the forward pass, all intermediate results such as $\\mathbb{u}$ are stored in memory so that when you run the backward pass, you can readily compute the gradient $\\frac{\\partial y}{\\partial \\mathbb{u}}$. However, this requires substantial memory for deep networks.<\/p>\n<p>Gradient checkpointing discards some intermediate results. As long as you know $\\mathbb{u}=g(\\mathbb{x})$, you can recompute $\\mathbb{u}$ from $\\mathbb{x}$ during the backward pass. 
This way, you don\u2019t need to store $\\mathbb{u}$ in memory, but you must compute $\\mathbb{u}$ twice: once for the forward pass and once for the backward pass.<\/p>\n<p>You can decide which intermediate results to discard. Applying gradient checkpointing to every two operations still requires storing many intermediate results. Applying it to larger blocks saves more memory.<\/p>\n<p>Referring to the model from the <a href=\"https:\/\/machinelearningmastery.com\/creating-a-llama-or-gpt-model-for-next-token-prediction\/\">previous article<\/a>, you can wrap every transformer block with gradient checkpointing:<\/p>\n<div id=\"urvanov-syntax-highlighter-69646924ccd30227228064\" class=\"urvanov-syntax-highlighter-syntax crayon-theme-classic urvanov-syntax-highlighter-font-monaco urvanov-syntax-highlighter-os-pc print-yes notranslate\" data-settings=\" minimize scroll-mouseover disable-anim\" style=\" margin-top: 12px; margin-bottom: 12px; font-size: 12px !important; line-height: 15px !important;\">\n<p><textarea wrap=\"soft\" class=\"urvanov-syntax-highlighter-plain print-no\" data-settings=\"dblclick\" readonly=\"readonly\" style=\"-moz-tab-size:4; -o-tab-size:4; -webkit-tab-size:4; tab-size:4; font-size: 12px !important; line-height: 15px !important;\"><br \/>\n&#8230;&#13;<br \/>\nclass LlamaModel(nn.Module):&#13;<br \/>\n    def __init__(self, config: LlamaConfig) -&gt; None:&#13;<br \/>\n        super().__init__()&#13;<br \/>\n        self.rotary_emb = RotaryPositionEncoding(&#13;<br \/>\n            config.hidden_size \/\/ config.num_attention_heads,&#13;<br \/>\n            config.max_position_embeddings,&#13;<br \/>\n        )&#13;<br \/>\n&#13;<br \/>\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)&#13;<br \/>\n        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])&#13;<br \/>\n        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)&#13;<br \/>\n&#13;<br \/>\n    def 
forward(self, input_ids: Tensor, attn_mask: Tensor) -&gt; Tensor:&#13;<br \/>\n        # Convert input token IDs to embeddings&#13;<br \/>\n        hidden_states = self.embed_tokens(input_ids)&#13;<br \/>\n        # Process through all transformer layers, then the final norm layer&#13;<br \/>\n        for layer in self.layers:&#13;<br \/>\n            # Previously:&#13;<br \/>\n            # hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)&#13;<br \/>\n            hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states, self.rotary_emb, attn_mask)&#13;<br \/>\n        hidden_states = self.norm(hidden_states)&#13;<br \/>\n        # Return the final hidden states&#13;<br \/>\n        return hidden_states<\/textarea><\/p>\n<\/p><\/div>\n<p>Only one line of code needs to change: in the for-loop under the <code>forward()<\/code> function, instead of calling the transformer block directly, use <code>torch.utils.checkpoint.checkpoint()<\/code>. This runs the forward pass with gradient checkpointing, discarding all intermediate results and retaining only the block\u2019s input and output. 
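<\/p>\n<p>As a quick sanity check, the technique can be exercised on a toy stack of layers (a sketch with made-up sizes, not the article\u2019s <code>LlamaModel<\/code>): because the backward pass recomputes the discarded activations, the checkpointed gradients match the plain forward pass exactly. Recent PyTorch releases also recommend passing <code>use_reentrant=False<\/code> explicitly:<\/p>

```python
# A minimal sketch (hypothetical toy model, not the article's LlamaModel):
# check that gradient checkpointing reproduces the plain-forward gradients,
# because the discarded activations are recomputed during backward.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layers = nn.ModuleList([nn.Linear(32, 32) for _ in range(4)])
x = torch.randn(8, 32, requires_grad=True)

def block(h, layer):
    # One "transformer block" stand-in: a linear layer plus activation
    return torch.relu(layer(h))

# Plain forward/backward: every intermediate activation is kept
h = x
for layer in layers:
    h = block(h, layer)
h.sum().backward()
plain_grad = x.grad.clone()

# Checkpointed forward/backward: only each block's input and output
# are retained; use_reentrant=False is the recommended variant
x.grad = None
h = x
for layer in layers:
    h = checkpoint(block, h, layer, use_reentrant=False)
h.sum().backward()

# The gradients agree; checkpointing trades recompute time for memory
assert torch.allclose(plain_grad, x.grad)
```
<p>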
During the backward pass, the discarded activations are recomputed from the block\u2019s saved input, so checkpointing trades extra computation for a much smaller activation-memory footprint.<\/p>\n<h2>Summary<\/h2>\n<p>In this article, you learned techniques for training a language model with limited memory. Specifically, you learned that:<\/p>\n<ul>\n<li>Several types of floating-point numbers exist, with some using less memory than others.<\/li>\n<li>Mixed-precision training automatically uses lower-precision floating-point numbers without sacrificing accuracy on critical operations.<\/li>\n<li>Gradient checkpointing trades time for memory during training.<\/li>\n<\/ul><\/div>\n","protected":false},"excerpt":{"rendered":"<p>Training a language model is memory-intensive, not only because the model itself is large but also because training data batches often contain long sequences. Training a model with limited memory is challenging. In this article, you will learn techniques that enable model training in memory-constrained environments. 
In particular, you will learn about: Low-precision floating-point numbers<\/p>\n","protected":false},"author":1,"featured_media":178034,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[42],"tags":[],"class_list":{"0":"post-178033","1":"post","2":"type-post","3":"status-publish","4":"format-standard","5":"has-post-thumbnail","7":"category-ai"},"_links":{"self":[{"href":"http:\/\/ktromedia.com\/index.php?rest_route=\/wp\/v2\/posts\/178033","targetHints":{"allow":["GET"]}}],"collection":[{"href":"http:\/\/ktromedia.com\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"http:\/\/ktromedia.com\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"http:\/\/ktromedia.com\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"http:\/\/ktromedia.com\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=178033"}],"version-history":[{"count":1,"href":"http:\/\/ktromedia.com\/index.php?rest_route=\/wp\/v2\/posts\/178033\/revisions"}],"predecessor-version":[{"id":178035,"href":"http:\/\/ktromedia.com\/index.php?rest_route=\/wp\/v2\/posts\/178033\/revisions\/178035"}],"wp:featuredmedia":[{"embeddable":true,"href":"http:\/\/ktromedia.com\/index.php?rest_route=\/wp\/v2\/media\/178034"}],"wp:attachment":[{"href":"http:\/\/ktromedia.com\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=178033"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"http:\/\/ktromedia.com\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=178033"},{"taxonomy":"post_tag","embeddable":true,"href":"http:\/\/ktromedia.com\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=178033"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}