Skip to content

Commit 36fd2e1

Browse files
gneculalearned_optimization authors
authored andcommitted
No public description
PiperOrigin-RevId: 681093047
1 parent 4bcaeb0 commit 36fd2e1

File tree

3 files changed

+11
-12
lines changed

3 files changed

+11
-12
lines changed

docs/notebooks/summary_tutorial.ipynb

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -997,9 +997,9 @@
997997
"id": "jNt9CNJf2HJN"
998998
},
999999
"source": [
1000-
"### jax.experimental.host_callback\n",
1000+
"### jax external callbacks\n",
10011001
"\n",
1002-
"Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in jax.experimental.host_callback.\n",
1002+
"Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.\n",
10031003
"\n",
10041004
"One can use this to print which is a quick way to get data out of a network."
10051005
]
@@ -1025,13 +1025,12 @@
10251025
}
10261026
],
10271027
"source": [
1028-
"from jax.experimental import host_callback as hcb\n",
10291028
"\n",
10301029
"\n",
10311030
"def loss(parameters):\n",
10321031
" loss = jnp.mean(parameters**2)\n",
10331032
" to_look_at = jnp.mean(123.)\n",
1034-
" hcb.id_print(to_look_at, name=\"to_look_at\")\n",
1033+
" jax.debug.print(\"to_look_at={}\", to_look_at)\n",
10351034
" return loss\n",
10361035
"\n",
10371036
"\n",

docs/notebooks/summary_tutorial.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,9 +461,9 @@ print(to_look_at)
461461

462462
+++ {"id": "jNt9CNJf2HJN"}
463463

464-
### jax.experimental.host_callback
464+
### jax external callbacks
465465

466-
Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in jax.experimental.host_callback.
466+
Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.
467467

468468
One can use this to print which is a quick way to get data out of a network.
469469

@@ -474,13 +474,12 @@ colab:
474474
id: 1Ih2LxP22MZD
475475
outputId: 0dd0b8ec-2c9e-414d-eadf-843122b7b8ab
476476
---
477-
from jax.experimental import host_callback as hcb
478477
479478
480479
def loss(parameters):
481480
loss = jnp.mean(parameters**2)
482481
to_look_at = jnp.mean(123.)
483-
hcb.id_print(to_look_at, name="to_look_at")
482+
jax.debug.print("to_look_at={}", to_look_at)
484483
return loss
485484
486485

docs/notebooks/summary_tutorial.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -348,20 +348,21 @@ def loss(parameters):
348348
print(to_look_at)
349349

350350
# + [markdown] id="jNt9CNJf2HJN"
351-
# ### jax.experimental.host_callback
351+
# ### jax external callbacks
352352
#
353-
# Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in jax.experimental.host_callback.
353+
# Jax has some support to send data back from an accelerator back to the host
354+
# while a jax program is running. This is exposed in
355+
# https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.
354356
#
355357
# One can use this to print which is a quick way to get data out of a network.
356358

357359
# + colab={"base_uri": "https://localhost:8080/"} id="1Ih2LxP22MZD" outputId="0dd0b8ec-2c9e-414d-eadf-843122b7b8ab"
358-
from jax.experimental import host_callback as hcb
359360

360361

361362
def loss(parameters):
362363
loss = jnp.mean(parameters**2)
363364
to_look_at = jnp.mean(123.)
364-
hcb.id_print(to_look_at, name="to_look_at")
365+
jax.debug.print("to_look_at={}", to_look_at)
365366
return loss
366367

367368

0 commit comments

Comments
 (0)