Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions examples/ml_dtypes/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# ml_dtypes Examples

Each sub-directory contains a self-contained example. The order in
which the examples are to appear is specified in `order.json` (an
array of directory names in the expected order).

In each example directory you'll find:

* `config.toml` - must conform to the specification outlined here:
https://docs.pyscript.net/latest/user-guide/configuration/ This is
parsed and ultimately turned into a JSON representation as part of
the package's API object.
* `setup.py` - Python code for contextual and environmental setup,
NOT SEEN BY THE END USER, but is run before the `code.py` code is
evaluated. Allows us to create useful (IPython) shims, avoid
repeating boilerplate and whatnot.
* `code.py` - the actual code added to the editor which forms the
practical example of using the package.
53 changes: 53 additions & 0 deletions examples/ml_dtypes/bfloat16_basics/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
A first look at ml_dtypes: the bfloat16 dtype.

ml_dtypes provides NumPy dtype extensions used widely in machine
learning (bfloat16, several float8 variants, sub-byte ints, etc.).
Importing the package also registers the dtypes with NumPy, so you
can use them anywhere a regular dtype is accepted.

Docs: https://github.com/jax-ml/ml_dtypes
"""
import numpy as np
from ml_dtypes import bfloat16
from IPython.core.display import display, HTML

heading("1. Creating a bfloat16 array")
note(
"bfloat16 is a 16-bit float with the same exponent range as "
"float32 (8 bits) but only 7 mantissa bits. It trades precision "
"for range, which suits ML workloads."
)

# Construct a bfloat16 array. Once ml_dtypes is imported, NumPy
# also recognizes the dtype by string name.
weights_f32 = rng.normal(size=8).astype(np.float32)
weights_bf16 = weights_f32.astype(bfloat16)

note("The same eight values, side by side at different precisions:")
comparison = np.empty(
(3, 8),
dtype=object,
)
comparison[0] = [f"{x:.6f}" for x in weights_f32]
comparison[1] = [f"{x:.6f}" for x in weights_bf16.astype(np.float32)]
comparison[2] = [f"{float(a) - float(b):+.2e}"
for a, b in zip(weights_f32, weights_bf16.astype(np.float32))]
rows = "".join(
"<tr>" + f"<th>{label}</th>" +
"".join(f"<td>{cell}</td>" for cell in comparison[i]) + "</tr>"
for i, label in enumerate(["float32", "bfloat16", "delta"])
)
display(HTML(f"<table>{rows}</table>"), append=True)

heading("2. Memory savings")
note(
f"float32 uses {weights_f32.nbytes} bytes for 8 values; "
f"bfloat16 uses {weights_bf16.nbytes}. Half the memory, "
"same exponent range."
)

heading("3. dtype is registered with NumPy")
# After importing ml_dtypes, np.dtype('bfloat16') resolves.
note(f"np.dtype('bfloat16') -> <code>{np.dtype('bfloat16')}</code>")
note(f"itemsize: <code>{np.dtype('bfloat16').itemsize}</code> bytes")
1 change: 1 addition & 0 deletions examples/ml_dtypes/bfloat16_basics/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
packages = ["ml_dtypes", "numpy"]
48 changes: 48 additions & 0 deletions examples/ml_dtypes/bfloat16_basics/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Shim IPython's display API onto PyScript so example code written in a
Jupyter/IPython idiom runs unmodified in the browser.
"""

import sys
import types
import js
from pyscript import window, HTML, display as _display

js.alert = window.alert


def display(*args, **kwargs):
"""Wrap pyscript.display so output lands in the example target."""
return _display(
*args, **kwargs, target=__pyscript_display_target__,
)


ipython = types.ModuleType("IPython")
core = types.ModuleType("IPython.core")
core_display = types.ModuleType("IPython.core.display")
core_display.display = display
core_display.HTML = HTML
ipython.core = core
core.display = core_display
ipython.get_ipython = lambda: None
ipython.display = core_display
sys.modules["IPython"] = ipython
sys.modules["IPython.core"] = core
sys.modules["IPython.core.display"] = core_display
sys.modules["IPython.display"] = core_display


def heading(text, level=2):
display(HTML(f"<h{level}>{text}</h{level}>"), append=True)


def note(text):
display(HTML(f"<p>{text}</p>"), append=True)


# Package imports for this example.
import numpy as np
from ml_dtypes import bfloat16

rng = np.random.default_rng(0)
80 changes: 80 additions & 0 deletions examples/ml_dtypes/float8_landscape/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# ---------------------------------------------------------------------
# A tour of the smaller dtypes: float8 variants and int4.
# ---------------------------------------------------------------------

heading("How many distinct values does each format represent?")
note(
"float8_e4m3fn favors precision (4 exponent, 3 mantissa). "
"float8_e5m2 favors range (5 exponent, 2 mantissa). "
"Both fit in a single byte, but they cover the number line "
"very differently."
)

# Enumerate every byte pattern for each 8-bit float type and decode
# it to float32. This gives us the complete value set of each format.
all_bytes = np.arange(256, dtype=np.uint8)


def decode_all(dtype):
"""Return the float32 values of all 256 byte patterns for a dtype."""
return all_bytes.view(dtype).astype(np.float32)


e4m3_values = decode_all(float8_e4m3fn)
e5m2_values = decode_all(float8_e5m2)

# Drop NaNs for plotting.
e4m3_finite = e4m3_values[np.isfinite(e4m3_values)]
e5m2_finite = e5m2_values[np.isfinite(e5m2_values)]

note(
f"float8_e4m3fn finite values: <strong>{len(e4m3_finite)}</strong>, "
f"max = {e4m3_finite.max():g}<br>"
f"float8_e5m2 finite values: <strong>{len(e5m2_finite)}</strong>, "
f"max = {e5m2_finite.max():g}"
)

# Plot the value distributions on a symlog axis so we can see how
# each format spaces its representable points.
fig, ax = plt.subplots(figsize=(9, 3.5))
ax.scatter(e4m3_finite, np.full_like(e4m3_finite, 1.0),
s=10, color="steelblue", label="float8_e4m3fn")
ax.scatter(e5m2_finite, np.full_like(e5m2_finite, 0.0),
s=10, color="crimson", label="float8_e5m2")
ax.set_xscale("symlog", linthresh=1e-3)
ax.set_yticks([0, 1])
ax.set_yticklabels(["e5m2", "e4m3fn"])
ax.set_xlabel("Representable value (symlog)")
ax.set_title("Where each float8 format places its points")
ax.legend(loc="upper center")
ax.grid(True, axis="x", alpha=0.3)
fig.tight_layout()
display(fig, append=True)

heading("Quantizing weights to int4")
note(
"Sub-byte integers like int4 are stored unpacked (one per byte) "
"but only the low 4 bits matter. Here we quantize a small "
"weight vector to the int4 range [-8, 7]."
)

weights = rng.normal(scale=2.0, size=12).astype(np.float32)
scale = np.max(np.abs(weights)) / 7.0
quantized = np.round(weights / scale).clip(-8, 7).astype(int4)
recovered = quantized.astype(np.float32) * scale

rows = "<tr><th>original</th>" + "".join(
f"<td>{w:+.2f}</td>" for w in weights
) + "</tr>"
rows += "<tr><th>int4 code</th>" + "".join(
f"<td>{int(q)}</td>" for q in quantized
) + "</tr>"
rows += "<tr><th>dequantized</th>" + "".join(
f"<td>{w:+.2f}</td>" for w in recovered
) + "</tr>"
display(HTML(f"<table>{rows}</table>"), append=True)

note(
f"Mean absolute error after round-trip: "
f"<strong>{np.abs(weights - recovered).mean():.3f}</strong>"
)
1 change: 1 addition & 0 deletions examples/ml_dtypes/float8_landscape/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
packages = ["ml_dtypes", "numpy", "matplotlib"]
26 changes: 26 additions & 0 deletions examples/ml_dtypes/float8_landscape/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Setup for the third example. No IPython shim here."""
import js
from pyscript import window, HTML, display as _display

js.alert = window.alert


def display(*args, **kwargs):
return _display(
*args, **kwargs, target=__pyscript_display_target__,
)


def heading(text, level=2):
display(HTML(f"<h{level}>{text}</h{level}>"), append=True)


def note(text):
display(HTML(f"<p>{text}</p>"), append=True)


import numpy as np
import matplotlib.pyplot as plt
from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2, int4

rng = np.random.default_rng(0)
57 changes: 57 additions & 0 deletions examples/ml_dtypes/low_precision_summation/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# ---------------------------------------------------------------------
# Quirks of low-precision arithmetic: a cautionary tale.
# ---------------------------------------------------------------------

heading("Summing 10,000 small numbers in bfloat16")
note(
"We draw 10,000 values uniformly from [0, 1). The true sum "
"is around 5,000. Watch what happens when bfloat16 accumulates."
)

samples = rng.uniform(size=10_000).astype(bfloat16)

# Naive: accumulate in bfloat16. Once the running total reaches 256,
# adding a value below 1 has no effect (the next representable
# bfloat16 above 256 is 258).
naive_sum = samples.sum()

# Better: accumulate in float32 and cast back at the end.
careful_sum = samples.sum(dtype=np.float32).astype(bfloat16)

# Reference: full float32 sum.
true_sum = samples.astype(np.float32).sum()

note(
f"Naive bfloat16 sum: <strong>{float(naive_sum):.1f}</strong><br>"
f"Accumulated in float32, cast to bfloat16: "
f"<strong>{float(careful_sum):.1f}</strong><br>"
f"True (float32) sum: <strong>{float(true_sum):.1f}</strong>"
)

heading("Why? bfloat16's spacing grows with magnitude")
note(
"Floating point numbers get sparser as they grow. Above 256, "
"consecutive bfloat16 values are 2 apart, so adding 0.5 is "
"indistinguishable from adding zero."
)

# Visualize the gap between adjacent bfloat16 values across magnitudes.
magnitudes = np.array([2.0 ** k for k in range(-4, 16)], dtype=bfloat16)
next_up = np.array(
[np.nextafter(m, bfloat16(np.inf)) for m in magnitudes],
dtype=bfloat16,
)
spacing = (next_up.astype(np.float32) - magnitudes.astype(np.float32))

fig, ax = plt.subplots(figsize=(8, 4))
ax.loglog(magnitudes.astype(np.float32), spacing, marker="o",
color="crimson", label="bfloat16 spacing")
ax.axvline(256, color="gray", linestyle="--",
label="x = 256 (spacing crosses 1)")
ax.set_xlabel("Value (x)")
ax.set_ylabel("Distance to next bfloat16")
ax.set_title("Bfloat16 quantization gap vs. magnitude")
ax.legend()
ax.grid(True, which="both", alpha=0.3)
fig.tight_layout()
display(fig, append=True)
1 change: 1 addition & 0 deletions examples/ml_dtypes/low_precision_summation/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
packages = ["ml_dtypes", "numpy", "matplotlib"]
26 changes: 26 additions & 0 deletions examples/ml_dtypes/low_precision_summation/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Setup for the second example. No IPython shim here."""
import js
from pyscript import window, HTML, display as _display

js.alert = window.alert


def display(*args, **kwargs):
return _display(
*args, **kwargs, target=__pyscript_display_target__,
)


def heading(text, level=2):
display(HTML(f"<h{level}>{text}</h{level}>"), append=True)


def note(text):
display(HTML(f"<p>{text}</p>"), append=True)


import numpy as np
import matplotlib.pyplot as plt
from ml_dtypes import bfloat16

rng = np.random.default_rng(0)
5 changes: 5 additions & 0 deletions examples/ml_dtypes/order.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[
"bfloat16_basics",
"low_precision_summation",
"float8_landscape"
]