diff --git a/examples/ml_dtypes/README.md b/examples/ml_dtypes/README.md new file mode 100644 index 0000000..fcb7370 --- /dev/null +++ b/examples/ml_dtypes/README.md @@ -0,0 +1,18 @@ +# ml_dtypes Examples + +Each sub-directory contains a self-contained example. The order in +which the examples are to appear is specified in `order.json` (an +array of directory names in the expected order). + +In each example directory you'll find: + +* `config.toml` - must conform to the specification outlined here: + https://docs.pyscript.net/latest/user-guide/configuration/. It is + parsed and ultimately turned into a JSON representation as part of + the package's API object. +* `setup.py` - Python code for contextual and environmental setup, + NOT SEEN BY THE END USER, but run before the `code.py` code is + evaluated. It lets us create useful shims (e.g. for IPython) and + avoid repeating boilerplate. +* `code.py` - the actual code added to the editor which forms the + practical example of using the package. diff --git a/examples/ml_dtypes/bfloat16_basics/code.py b/examples/ml_dtypes/bfloat16_basics/code.py new file mode 100644 index 0000000..3b6ded7 --- /dev/null +++ b/examples/ml_dtypes/bfloat16_basics/code.py @@ -0,0 +1,53 @@ +""" +A first look at ml_dtypes: the bfloat16 dtype. + +ml_dtypes provides NumPy dtype extensions used widely in machine +learning (bfloat16, several float8 variants, sub-byte ints, etc.). +Importing the package also registers the dtypes with NumPy, so you +can use them anywhere a regular dtype is accepted. + +Docs: https://github.com/jax-ml/ml_dtypes +""" +import numpy as np +from ml_dtypes import bfloat16 +from IPython.core.display import display, HTML + +heading("1. Creating a bfloat16 array") +note( + "bfloat16 is a 16-bit float with the same exponent range as " + "float32 (8 bits) but only 7 mantissa bits. It trades precision " + "for range, which suits ML workloads." +) + +# Construct a bfloat16 array. 
Once ml_dtypes is imported, NumPy +# also recognizes the dtype by string name. +weights_f32 = rng.normal(size=8).astype(np.float32) +weights_bf16 = weights_f32.astype(bfloat16) + +note("The same eight values, side by side at different precisions:") +comparison = np.empty( + (3, 8), + dtype=object, +) +comparison[0] = [f"{x:.6f}" for x in weights_f32] +comparison[1] = [f"{x:.6f}" for x in weights_bf16.astype(np.float32)] +comparison[2] = [f"{float(a) - float(b):+.2e}" + for a, b in zip(weights_f32, weights_bf16.astype(np.float32))] +rows = "".join( + "" + f"{label}" + + "".join(f"{cell}" for cell in comparison[i]) + "" + for i, label in enumerate(["float32", "bfloat16", "delta"]) +) +display(HTML(f"{rows}
"), append=True) + +heading("2. Memory savings") +note( + f"float32 uses {weights_f32.nbytes} bytes for 8 values; " + f"bfloat16 uses {weights_bf16.nbytes}. Half the memory, " + "same exponent range." +) + +heading("3. dtype is registered with NumPy") +# After importing ml_dtypes, np.dtype('bfloat16') resolves. +note(f"np.dtype('bfloat16') -> {np.dtype('bfloat16')}") +note(f"itemsize: {np.dtype('bfloat16').itemsize} bytes") diff --git a/examples/ml_dtypes/bfloat16_basics/config.toml b/examples/ml_dtypes/bfloat16_basics/config.toml new file mode 100644 index 0000000..226c052 --- /dev/null +++ b/examples/ml_dtypes/bfloat16_basics/config.toml @@ -0,0 +1 @@ +packages = ["ml_dtypes", "numpy"] diff --git a/examples/ml_dtypes/bfloat16_basics/setup.py b/examples/ml_dtypes/bfloat16_basics/setup.py new file mode 100644 index 0000000..c3dcde2 --- /dev/null +++ b/examples/ml_dtypes/bfloat16_basics/setup.py @@ -0,0 +1,48 @@ +""" +Shim IPython's display API onto PyScript so example code written in a +Jupyter/IPython idiom runs unmodified in the browser. +""" + +import sys +import types +import js +from pyscript import window, HTML, display as _display + +js.alert = window.alert + + +def display(*args, **kwargs): + """Wrap pyscript.display so output lands in the example target.""" + return _display( + *args, **kwargs, target=__pyscript_display_target__, + ) + + +ipython = types.ModuleType("IPython") +core = types.ModuleType("IPython.core") +core_display = types.ModuleType("IPython.core.display") +core_display.display = display +core_display.HTML = HTML +ipython.core = core +core.display = core_display +ipython.get_ipython = lambda: None +ipython.display = core_display +sys.modules["IPython"] = ipython +sys.modules["IPython.core"] = core +sys.modules["IPython.core.display"] = core_display +sys.modules["IPython.display"] = core_display + + +def heading(text, level=2): + display(HTML(f"{text}"), append=True) + + +def note(text): + display(HTML(f"

{text}

"), append=True) + + +# Package imports for this example. +import numpy as np +from ml_dtypes import bfloat16 + +rng = np.random.default_rng(0) diff --git a/examples/ml_dtypes/float8_landscape/code.py b/examples/ml_dtypes/float8_landscape/code.py new file mode 100644 index 0000000..b47d100 --- /dev/null +++ b/examples/ml_dtypes/float8_landscape/code.py @@ -0,0 +1,80 @@ +# --------------------------------------------------------------------- +# A tour of the smaller dtypes: float8 variants and int4. +# --------------------------------------------------------------------- + +heading("How many distinct values does each format represent?") +note( + "float8_e4m3fn favors precision (4 exponent, 3 mantissa). " + "float8_e5m2 favors range (5 exponent, 2 mantissa). " + "Both fit in a single byte, but they cover the number line " + "very differently." +) + +# Enumerate every byte pattern for each 8-bit float type and decode +# it to float32. This gives us the complete value set of each format. +all_bytes = np.arange(256, dtype=np.uint8) + + +def decode_all(dtype): + """Return the float32 values of all 256 byte patterns for a dtype.""" + return all_bytes.view(dtype).astype(np.float32) + + +e4m3_values = decode_all(float8_e4m3fn) +e5m2_values = decode_all(float8_e5m2) + +# Keep only finite values (drops NaNs and infinities) for plotting. +e4m3_finite = e4m3_values[np.isfinite(e4m3_values)] +e5m2_finite = e5m2_values[np.isfinite(e5m2_values)] + +note( + f"float8_e4m3fn finite values: {len(e4m3_finite)}, " + f"max = {e4m3_finite.max():g}
" + f"float8_e5m2 finite values: {len(e5m2_finite)}, " + f"max = {e5m2_finite.max():g}" +) + +# Plot the value distributions on a symlog axis so we can see how +# each format spaces its representable points. +fig, ax = plt.subplots(figsize=(9, 3.5)) +ax.scatter(e4m3_finite, np.full_like(e4m3_finite, 1.0), + s=10, color="steelblue", label="float8_e4m3fn") +ax.scatter(e5m2_finite, np.full_like(e5m2_finite, 0.0), + s=10, color="crimson", label="float8_e5m2") +ax.set_xscale("symlog", linthresh=1e-3) +ax.set_yticks([0, 1]) +ax.set_yticklabels(["e5m2", "e4m3fn"]) +ax.set_xlabel("Representable value (symlog)") +ax.set_title("Where each float8 format places its points") +ax.legend(loc="upper center") +ax.grid(True, axis="x", alpha=0.3) +fig.tight_layout() +display(fig, append=True) + +heading("Quantizing weights to int4") +note( + "Sub-byte integers like int4 are stored unpacked (one per byte) " + "but only the low 4 bits matter. Here we quantize a small " + "weight vector to the int4 range [-8, 7]." +) + +weights = rng.normal(scale=2.0, size=12).astype(np.float32) +scale = np.max(np.abs(weights)) / 7.0 +quantized = np.round(weights / scale).clip(-8, 7).astype(int4) +recovered = quantized.astype(np.float32) * scale + +rows = "original" + "".join( + f"{w:+.2f}" for w in weights +) + "" +rows += "int4 code" + "".join( + f"{int(q)}" for q in quantized +) + "" +rows += "dequantized" + "".join( + f"{w:+.2f}" for w in recovered +) + "" +display(HTML(f"{rows}
"), append=True) + +note( + f"Mean absolute error after round-trip: " + f"{np.abs(weights - recovered).mean():.3f}" +) diff --git a/examples/ml_dtypes/float8_landscape/config.toml b/examples/ml_dtypes/float8_landscape/config.toml new file mode 100644 index 0000000..f940b3d --- /dev/null +++ b/examples/ml_dtypes/float8_landscape/config.toml @@ -0,0 +1 @@ +packages = ["ml_dtypes", "numpy", "matplotlib"] diff --git a/examples/ml_dtypes/float8_landscape/setup.py b/examples/ml_dtypes/float8_landscape/setup.py new file mode 100644 index 0000000..1136bd5 --- /dev/null +++ b/examples/ml_dtypes/float8_landscape/setup.py @@ -0,0 +1,26 @@ +"""Setup for the third example. No IPython shim here.""" +import js +from pyscript import window, HTML, display as _display + +js.alert = window.alert + + +def display(*args, **kwargs): + return _display( + *args, **kwargs, target=__pyscript_display_target__, + ) + + +def heading(text, level=2): + display(HTML(f"{text}"), append=True) + + +def note(text): + display(HTML(f"

{text}

"), append=True) + + +import numpy as np +import matplotlib.pyplot as plt +from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2, int4 + +rng = np.random.default_rng(0) diff --git a/examples/ml_dtypes/low_precision_summation/code.py b/examples/ml_dtypes/low_precision_summation/code.py new file mode 100644 index 0000000..df4c575 --- /dev/null +++ b/examples/ml_dtypes/low_precision_summation/code.py @@ -0,0 +1,57 @@ +# --------------------------------------------------------------------- +# Quirks of low-precision arithmetic: a cautionary tale. +# --------------------------------------------------------------------- + +heading("Summing 10,000 small numbers in bfloat16") +note( + "We draw 10,000 values uniformly from [0, 1). The true sum " + "is around 5,000. Watch what happens when bfloat16 accumulates." +) + +samples = rng.uniform(size=10_000).astype(bfloat16) + +# Naive: accumulate in bfloat16. Once the running total reaches 256, +# adding a value below 1 has no effect (the next representable +# bfloat16 above 256 is 258). +naive_sum = samples.sum() + +# Better: accumulate in float32 and cast back at the end. +careful_sum = samples.sum(dtype=np.float32).astype(bfloat16) + +# Reference: full float32 sum. +true_sum = samples.astype(np.float32).sum() + +note( + f"Naive bfloat16 sum: {float(naive_sum):.1f}
" + f"Accumulated in float32, cast to bfloat16: " + f"{float(careful_sum):.1f}
" + f"True (float32) sum: {float(true_sum):.1f}" +) + +heading("Why? bfloat16's spacing grows with magnitude") +note( + "Floating point numbers get sparser as they grow. Above 256, " + "consecutive bfloat16 values are 2 apart, so adding 0.5 is " + "indistinguishable from adding zero." +) + +# Visualize the gap between adjacent bfloat16 values across magnitudes. +magnitudes = np.array([2.0 ** k for k in range(-4, 16)], dtype=bfloat16) +next_up = np.array( + [np.nextafter(m, bfloat16(np.inf)) for m in magnitudes], + dtype=bfloat16, +) +spacing = (next_up.astype(np.float32) - magnitudes.astype(np.float32)) + +fig, ax = plt.subplots(figsize=(8, 4)) +ax.loglog(magnitudes.astype(np.float32), spacing, marker="o", + color="crimson", label="bfloat16 spacing") +ax.axvline(256, color="gray", linestyle="--", + label="x = 256 (spacing crosses 1)") +ax.set_xlabel("Value (x)") +ax.set_ylabel("Distance to next bfloat16") +ax.set_title("Bfloat16 quantization gap vs. magnitude") +ax.legend() +ax.grid(True, which="both", alpha=0.3) +fig.tight_layout() +display(fig, append=True) diff --git a/examples/ml_dtypes/low_precision_summation/config.toml b/examples/ml_dtypes/low_precision_summation/config.toml new file mode 100644 index 0000000..f940b3d --- /dev/null +++ b/examples/ml_dtypes/low_precision_summation/config.toml @@ -0,0 +1 @@ +packages = ["ml_dtypes", "numpy", "matplotlib"] diff --git a/examples/ml_dtypes/low_precision_summation/setup.py b/examples/ml_dtypes/low_precision_summation/setup.py new file mode 100644 index 0000000..c1a20dd --- /dev/null +++ b/examples/ml_dtypes/low_precision_summation/setup.py @@ -0,0 +1,26 @@ +"""Setup for the second example. 
No IPython shim here.""" +import js +from pyscript import window, HTML, display as _display + +js.alert = window.alert + + +def display(*args, **kwargs): + return _display( + *args, **kwargs, target=__pyscript_display_target__, + ) + + +def heading(text, level=2): + display(HTML(f"{text}"), append=True) + + +def note(text): + display(HTML(f"

{text}

"), append=True) + + +import numpy as np +import matplotlib.pyplot as plt +from ml_dtypes import bfloat16 + +rng = np.random.default_rng(0) diff --git a/examples/ml_dtypes/order.json b/examples/ml_dtypes/order.json new file mode 100644 index 0000000..3a0096f --- /dev/null +++ b/examples/ml_dtypes/order.json @@ -0,0 +1,5 @@ +[ + "bfloat16_basics", + "low_precision_summation", + "float8_landscape" +]