Skip to content

Commit 1937d48

Browse files
authored
Generate API (#19530)
* API Generator for Keras * API Generator for Keras * Generates API Gen via api_gen.sh * Remove recursive import of _tf_keras * Generate API Files via api_gen.sh
1 parent 559f1dd commit 1937d48

File tree

754 files changed

+6889
-3479
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

754 files changed

+6889
-3479
lines changed

.github/workflows/actions.yml

+16-7
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ jobs:
2424
KERAS_HOME: .github/workflows/config/${{ matrix.backend }}
2525
steps:
2626
- uses: actions/checkout@v4
27-
- name: Check for changes in keras/applications
27+
- name: Check for changes in keras/src/applications
2828
uses: dorny/paths-filter@v3
2929
id: filter
3030
with:
3131
filters: |
3232
applications:
33-
- 'keras/applications/**'
33+
- 'keras/src/applications/**'
3434
- name: Set up Python
3535
uses: actions/setup-python@v5
3636
with:
@@ -49,13 +49,13 @@ jobs:
4949
run: |
5050
pip install -r requirements.txt --progress-bar off --upgrade
5151
pip uninstall -y keras keras-nightly
52-
pip install tf_keras==2.16.0rc0 --progress-bar off --upgrade
52+
pip install tf_keras==2.16.0 --progress-bar off --upgrade
5353
pip install -e "." --progress-bar off --upgrade
5454
- name: Test applications with pytest
5555
if: ${{ steps.filter.outputs.applications == 'true' }}
5656
run: |
57-
pytest keras/applications --cov=keras/applications
58-
coverage xml --include='keras/applications/*' -o apps-coverage.xml
57+
pytest keras/src/applications --cov=keras/src/applications
58+
coverage xml --include='keras/src/applications/*' -o apps-coverage.xml
5959
- name: Codecov keras.applications
6060
if: ${{ steps.filter.outputs.applications == 'true' }}
6161
uses: codecov/codecov-action@v4
@@ -80,8 +80,8 @@ jobs:
8080
pytest integration_tests/torch_workflow_test.py
8181
- name: Test with pytest
8282
run: |
83-
pytest keras --ignore keras/applications --cov=keras
84-
coverage xml --omit='keras/applications/*' -o core-coverage.xml
83+
pytest keras --ignore keras/src/applications --cov=keras
84+
coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml
8585
- name: Codecov keras
8686
uses: codecov/codecov-action@v4
8787
with:
@@ -115,5 +115,14 @@ jobs:
115115
pip install -r requirements.txt --progress-bar off --upgrade
116116
pip uninstall -y keras keras-nightly
117117
pip install -e "." --progress-bar off --upgrade
118+
- name: Check for API changes
119+
run: |
120+
bash shell/api_gen.sh
121+
git status
122+
clean=$(git status | grep "nothing to commit")
123+
if [ -z "$clean" ]; then
124+
echo "Please run shell/api_gen.sh to generate API."
125+
exit 1
126+
fi
118127
- name: Lint
119128
run: bash shell/lint.sh

.github/workflows/nightly.yml

+11-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ jobs:
5555
pytest integration_tests/torch_workflow_test.py
5656
- name: Test with pytest
5757
run: |
58-
pytest keras --ignore keras/applications --cov=keras
58+
pytest keras --ignore keras/src/applications --cov=keras
5959
6060
format:
6161
name: Check the code format
@@ -81,6 +81,15 @@ jobs:
8181
pip install -r requirements.txt --progress-bar off --upgrade
8282
pip uninstall -y keras keras-nightly
8383
pip install -e "." --progress-bar off --upgrade
84+
- name: Check for API changes
85+
run: |
86+
bash shell/api_gen.sh
87+
git status
88+
clean=$(git status | grep "nothing to commit")
89+
if [ -z "$clean" ]; then
90+
echo "Please run shell/api_gen.sh to generate API."
91+
exit 1
92+
fi
8493
- name: Lint
8594
run: bash shell/lint.sh
8695

@@ -108,4 +117,4 @@ jobs:
108117
with:
109118
password: ${{ secrets.PYPI_NIGHTLY_API_TOKEN }}
110119
packages-dir: dist/
111-
verbose: true
120+
verbose: true

api_gen.py

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
"""Script to generate keras public API in `keras/api` directory.
2+
3+
Usage:
4+
5+
Run via `./shell/api_gen.sh`.
6+
It generates API and formats user and generated APIs.
7+
"""
8+
9+
import os
10+
import shutil
11+
12+
import namex
13+
14+
package = "keras"
15+
16+
17+
def ignore_files(_, filenames):
18+
return [f for f in filenames if f.endswith("_test.py")]
19+
20+
21+
def create_legacy_directory():
22+
API_DIR = os.path.join(package, "api")
23+
# Make keras/_tf_keras/ by copying keras/
24+
tf_keras_dirpath_parent = os.path.join(API_DIR, "_tf_keras")
25+
tf_keras_dirpath = os.path.join(tf_keras_dirpath_parent, "keras")
26+
os.makedirs(tf_keras_dirpath, exist_ok=True)
27+
with open(os.path.join(tf_keras_dirpath_parent, "__init__.py"), "w") as f:
28+
f.write("from keras.api._tf_keras import keras\n")
29+
with open(os.path.join(API_DIR, "__init__.py")) as f:
30+
init_file = f.read()
31+
init_file = init_file.replace(
32+
"from keras.api import _legacy",
33+
"from keras.api import _tf_keras",
34+
)
35+
with open(os.path.join(API_DIR, "__init__.py"), "w") as f:
36+
f.write(init_file)
37+
# Remove the import of `_tf_keras` in `keras/_tf_keras/keras/__init__.py`
38+
init_file = init_file.replace("from keras.api import _tf_keras\n", "\n")
39+
with open(os.path.join(tf_keras_dirpath, "__init__.py"), "w") as f:
40+
f.write(init_file)
41+
for dirname in os.listdir(API_DIR):
42+
dirpath = os.path.join(API_DIR, dirname)
43+
if os.path.isdir(dirpath) and dirname not in (
44+
"_legacy",
45+
"_tf_keras",
46+
"src",
47+
):
48+
destpath = os.path.join(tf_keras_dirpath, dirname)
49+
if os.path.exists(destpath):
50+
shutil.rmtree(destpath)
51+
shutil.copytree(
52+
dirpath,
53+
destpath,
54+
ignore=ignore_files,
55+
)
56+
57+
# Copy keras/_legacy/ file contents to keras/_tf_keras/keras
58+
legacy_submodules = [
59+
path[:-3]
60+
for path in os.listdir(os.path.join(package, "src", "legacy"))
61+
if path.endswith(".py")
62+
]
63+
legacy_submodules += [
64+
path
65+
for path in os.listdir(os.path.join(package, "src", "legacy"))
66+
if os.path.isdir(os.path.join(package, "src", "legacy", path))
67+
]
68+
69+
for root, _, fnames in os.walk(os.path.join(package, "_legacy")):
70+
for fname in fnames:
71+
if fname.endswith(".py"):
72+
legacy_fpath = os.path.join(root, fname)
73+
tf_keras_root = root.replace("/_legacy", "/_tf_keras/keras")
74+
core_api_fpath = os.path.join(
75+
root.replace("/_legacy", ""), fname
76+
)
77+
if not os.path.exists(tf_keras_root):
78+
os.makedirs(tf_keras_root)
79+
tf_keras_fpath = os.path.join(tf_keras_root, fname)
80+
with open(legacy_fpath) as f:
81+
legacy_contents = f.read()
82+
legacy_contents = legacy_contents.replace(
83+
"keras.api._legacy", "keras.api._tf_keras.keras"
84+
)
85+
if os.path.exists(core_api_fpath):
86+
with open(core_api_fpath) as f:
87+
core_api_contents = f.read()
88+
core_api_contents = core_api_contents.replace(
89+
"from keras.api import _tf_keras\n", ""
90+
)
91+
for legacy_submodule in legacy_submodules:
92+
core_api_contents = core_api_contents.replace(
93+
f"from keras.api import {legacy_submodule}\n",
94+
"",
95+
)
96+
core_api_contents = core_api_contents.replace(
97+
f"keras.api.{legacy_submodule}",
98+
f"keras.api._tf_keras.keras.{legacy_submodule}",
99+
)
100+
legacy_contents = core_api_contents + "\n" + legacy_contents
101+
with open(tf_keras_fpath, "w") as f:
102+
f.write(legacy_contents)
103+
104+
# Delete keras/api/_legacy/
105+
shutil.rmtree(os.path.join(API_DIR, "_legacy"))
106+
107+
108+
def export_version_string():
109+
API_INIT = os.path.join(package, "api", "__init__.py")
110+
with open(API_INIT) as f:
111+
contents = f.read()
112+
with open(API_INIT, "w") as f:
113+
contents += "from keras.src.version import __version__\n"
114+
f.write(contents)
115+
116+
117+
def update_package_init():
118+
contents = """
119+
# Import everything from /api/ into keras.
120+
from keras.api import * # noqa: F403
121+
from keras.api import __version__ # Import * ignores names start with "_".
122+
123+
import os
124+
125+
# Add everything in /api/ to the module search path.
126+
__path__.append(os.path.join(os.path.dirname(__file__), "api")) # noqa: F405
127+
128+
# Don't pollute namespace.
129+
del os
130+
131+
# Never autocomplete `.src` or `.api` on an imported keras object.
132+
def __dir__():
133+
keys = dict.fromkeys((globals().keys()))
134+
keys.pop("src")
135+
keys.pop("api")
136+
return list(keys)
137+
138+
139+
# Don't import `.src` or `.api` during `from keras import *`.
140+
__all__ = [
141+
name
142+
for name in globals().keys()
143+
if not (name.startswith("_") or name in ("src", "api"))
144+
]"""
145+
with open(os.path.join(package, "__init__.py")) as f:
146+
init_contents = f.read()
147+
with open(os.path.join(package, "__init__.py"), "w") as f:
148+
f.write(init_contents.replace("\nfrom keras import api", contents))
149+
150+
151+
if __name__ == "__main__":
152+
# Backup the `keras/__init__.py` and restore it on error in api gen.
153+
os.makedirs(os.path.join(package, "api"), exist_ok=True)
154+
init_fname = os.path.join(package, "__init__.py")
155+
backup_init_fname = os.path.join(package, "__init__.py.bak")
156+
try:
157+
if os.path.exists(init_fname):
158+
shutil.move(init_fname, backup_init_fname)
159+
# Generates `keras/api` directory.
160+
namex.generate_api_files(
161+
"keras", code_directory="src", target_directory="api"
162+
)
163+
# Creates `keras/__init__.py` importing from `keras/api`
164+
update_package_init()
165+
except Exception as e:
166+
if os.path.exists(backup_init_fname):
167+
shutil.move(backup_init_fname, init_fname)
168+
raise e
169+
finally:
170+
if os.path.exists(backup_init_fname):
171+
os.remove(backup_init_fname)
172+
# Add __version__ to keras package
173+
export_version_string()
174+
# Creates `_tf_keras` with full keras API
175+
create_legacy_directory()

conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pytest # noqa: E402
1616

17-
from keras.backend import backend # noqa: E402
17+
from keras.src.backend import backend # noqa: E402
1818

1919

2020
def pytest_configure(config):

integration_tests/basic_full_flow.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import pytest
33

44
import keras
5-
from keras import layers
6-
from keras import losses
7-
from keras import metrics
8-
from keras import optimizers
9-
from keras import testing
5+
from keras.src import layers
6+
from keras.src import losses
7+
from keras.src import metrics
8+
from keras.src import optimizers
9+
from keras.src import testing
1010

1111

1212
class MyModel(keras.Model):

integration_tests/dataset_tests/boston_housing_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from keras import testing
2-
from keras.datasets import boston_housing
1+
from keras.src import testing
2+
from keras.src.datasets import boston_housing
33

44

55
class BostonHousingTest(testing.TestCase):

integration_tests/dataset_tests/california_housing_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from keras import testing
2-
from keras.datasets import california_housing
1+
from keras.src import testing
2+
from keras.src.datasets import california_housing
33

44

55
class CaliforniaHousingTest(testing.TestCase):

integration_tests/dataset_tests/cifar100_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

3-
from keras import testing
4-
from keras.datasets import cifar100
3+
from keras.src import testing
4+
from keras.src.datasets import cifar100
55

66

77
class Cifar100LoadDataTest(testing.TestCase):

integration_tests/dataset_tests/cifar10_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

3-
from keras import testing
4-
from keras.datasets import cifar10
3+
from keras.src import testing
4+
from keras.src.datasets import cifar10
55

66

77
class Cifar10LoadDataTest(testing.TestCase):

integration_tests/dataset_tests/fashion_mnist_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

3-
from keras import testing
4-
from keras.datasets import fashion_mnist
3+
from keras.src import testing
4+
from keras.src.datasets import fashion_mnist
55

66

77
class FashionMnistLoadDataTest(testing.TestCase):

integration_tests/dataset_tests/imdb_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

3-
from keras import testing
4-
from keras.datasets import imdb
3+
from keras.src import testing
4+
from keras.src.datasets import imdb
55

66

77
class ImdbLoadDataTest(testing.TestCase):

integration_tests/dataset_tests/mnist_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

3-
from keras import testing
4-
from keras.datasets import mnist
3+
from keras.src import testing
4+
from keras.src.datasets import mnist
55

66

77
class MnistLoadDataTest(testing.TestCase):

integration_tests/dataset_tests/reuters_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

3-
from keras import testing
4-
from keras.datasets import reuters
3+
from keras.src import testing
4+
from keras.src.datasets import reuters
55

66

77
class ReutersLoadDataTest(testing.TestCase):

0 commit comments

Comments
 (0)