Skip to content

Commit e736de2

Browse files
Merge branch 'master' into fix-node23
2 parents 93a10af + 407c6e5 commit e736de2

12 files changed

Lines changed: 176 additions & 12 deletions

File tree

.github/workflows/tfjs-ci.yml

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
name: TFJS Continuous Integration
2+
3+
on:
4+
push:
5+
branches: [ "master" ]
6+
pull_request:
7+
branches: [ "master" ]
8+
workflow_dispatch:
9+
10+
permissions:
11+
contents: read
12+
13+
jobs:
14+
test:
15+
runs-on: ubuntu-latest
16+
steps:
17+
- uses: bazel-contrib/setup-bazel@0.14.0
18+
with:
19+
# Avoid downloading Bazel every time.
20+
bazelisk-cache: true
21+
# Store build cache per workflow.
22+
disk-cache: ${{ github.workflow }}-cpu
23+
# Share repository cache between workflows.
24+
repository-cache: true
25+
- uses: actions/checkout@v4
26+
- name: Test TFJS CPU
27+
uses: actions/setup-node@v4
28+
with:
29+
node-version: 20.x
30+
cache: 'npm'
31+
- run: npm i -g yarn
32+
- run: yarn install
33+
- run: yarn test-cpu
34+
35+
test-gpu-mac:
36+
runs-on: macos-latest-xlarge # consumer gpu
37+
steps:
38+
- uses: bazel-contrib/setup-bazel@0.14.0
39+
with:
40+
# Avoid downloading Bazel every time.
41+
bazelisk-cache: true
42+
# Store build cache per workflow.
43+
disk-cache: ${{ github.workflow }}-gpu-mac
44+
# Share repository cache between workflows.
45+
repository-cache: true
46+
- uses: actions/checkout@v4
47+
- name: Test TFJS GPU
48+
uses: actions/setup-node@v4
49+
with:
50+
node-version: 20.x
51+
cache: 'npm'
52+
- run: npm i -g yarn
53+
- run: yarn install
54+
- run: yarn test-gpu

BUILD.bazel

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,31 @@ headless_flag(
4949
)
5050

5151
test_suite(
52-
name = "tests",
52+
name = "tests_cpu",
5353
tests = [
5454
"//tfjs-backend-cpu:tests",
5555
"//tfjs-backend-wasm:tests",
56-
"//tfjs-backend-webgl:tests",
5756
"//tfjs-converter:tests",
5857
"//tfjs-core:tests",
5958
"//tfjs-data:tests",
60-
"//tfjs-layers:tests",
6159
"//tfjs-tfdf:tests",
6260
"//tfjs-tflite:tests",
6361
],
6462
)
63+
64+
test_suite(
65+
name = "tests_gpu",
66+
tests = [
67+
"//tfjs-backend-webgl:tests",
68+
"//tfjs-backend-webgpu:tests",
69+
"//tfjs-layers:tests",
70+
],
71+
)
72+
73+
test_suite(
74+
name = "tests",
75+
tests = [
76+
":tests_cpu",
77+
":tests_gpu",
78+
],
79+
)

package.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@
8080
"scripts": {
8181
"lint": "tslint -p tsconfig_tslint.json",
8282
"test": "bazel test //:tests",
83+
"test-cpu": "bazel test --test_output=all //:tests_cpu",
84+
"test-gpu": "bazel test --test_output=all //:tests_gpu",
85+
"test-non-bazel": "cd link-package && yarn build-deps-for --all",
86+
"build": "cd link-package && yarn build",
8387
"test-packages-ci": "yarn generate-cloudbuild-for-packages && ./scripts/run-build.sh",
8488
"nightly-cloudbuild": "NIGHTLY=true yarn generate-cloudbuild-for-packages && gcloud builds submit . --config=cloudbuild_generated.yml --substitutions=_NIGHTLY=true",
8589
"generate-cloudbuild-for-packages": "ts-node -s ./scripts/generate_cloudbuild_for_packages.ts",

tfjs-backend-webgl/BUILD.bazel

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ tfjs_web_test(
116116
"bs_chrome_mac",
117117
"bs_android_10",
118118
],
119+
local_browser = select({
120+
"@bazel_tools//src/conditions:linux_x86_64": "chrome_webgpu_linux",
121+
"@bazel_tools//src/conditions:windows": "chrome_webgpu",
122+
"//conditions:default": "chrome_webgpu",
123+
}),
119124
static_files = STATIC_FILES,
120125
)
121126

@@ -137,6 +142,11 @@ tfjs_web_test(
137142
"bs_safari_mac",
138143
"bs_ios_12",
139144
],
145+
local_browser = select({
146+
"@bazel_tools//src/conditions:linux_x86_64": "chrome_webgpu_linux",
147+
"@bazel_tools//src/conditions:windows": "chrome_webgpu",
148+
"//conditions:default": "chrome_webgpu",
149+
}),
140150
static_files = STATIC_FILES,
141151
)
142152

@@ -156,6 +166,11 @@ tfjs_web_test(
156166
],
157167
headless = False,
158168
presubmit_browsers = [], # Only run in nightly
169+
local_browser = select({
170+
"@bazel_tools//src/conditions:linux_x86_64": "chrome_webgpu_linux",
171+
"@bazel_tools//src/conditions:windows": "chrome_webgpu",
172+
"//conditions:default": "chrome_webgpu",
173+
}),
159174
static_files = STATIC_FILES,
160175
)
161176

@@ -175,6 +190,11 @@ tfjs_web_test(
175190
],
176191
headless = False,
177192
presubmit_browsers = [], # Only run in nightly
193+
local_browser = select({
194+
"@bazel_tools//src/conditions:linux_x86_64": "chrome_webgpu_linux",
195+
"@bazel_tools//src/conditions:windows": "chrome_webgpu",
196+
"//conditions:default": "chrome_webgpu",
197+
}),
178198
static_files = STATIC_FILES,
179199
)
180200

@@ -194,6 +214,11 @@ tfjs_web_test(
194214
],
195215
headless = False,
196216
presubmit_browsers = [], # Only run in nightly
217+
local_browser = select({
218+
"@bazel_tools//src/conditions:linux_x86_64": "chrome_webgpu_linux",
219+
"@bazel_tools//src/conditions:windows": "chrome_webgpu",
220+
"//conditions:default": "chrome_webgpu",
221+
}),
197222
static_files = STATIC_FILES,
198223
)
199224

@@ -213,6 +238,11 @@ tfjs_web_test(
213238
],
214239
headless = False,
215240
presubmit_browsers = [], # Only run in nightly
241+
local_browser = select({
242+
"@bazel_tools//src/conditions:linux_x86_64": "chrome_webgpu_linux",
243+
"@bazel_tools//src/conditions:windows": "chrome_webgpu",
244+
"//conditions:default": "chrome_webgpu",
245+
}),
216246
static_files = STATIC_FILES,
217247
)
218248

tfjs-backend-webgpu/BUILD.bazel

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,10 @@ tfjs_web_test(
116116
}),
117117
static_files = STATIC_FILES,
118118
)
119+
120+
test_suite(
121+
name = "tests",
122+
tests = [
123+
":tfjs-backend-webgpu_test",
124+
],
125+
)

tfjs-backend-webgpu/src/backend_webgpu.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -594,16 +594,19 @@ export class WebGPUBackend extends KernelBackend {
594594
* @param dataId The source tensor.
595595
*/
596596
override readToGPU(dataId: DataId): GPUData {
597-
const srcTensorData = this.tensorMap.get(dataId);
598-
const {values, dtype, shape, resource} = srcTensorData;
597+
let srcTensorData = this.tensorMap.get(dataId);
598+
const {values, dtype, shape} = srcTensorData;
599+
let resource = srcTensorData.resource;
599600

600601
if (dtype === 'complex64') {
601602
throw new Error('Does not support reading buffer for complex64 dtype.');
602603
}
603604

604605
if (resource == null) {
605606
if (values != null) {
606-
throw new Error('Data is not on GPU but on CPU.');
607+
this.uploadToGPU(dataId);
608+
srcTensorData = this.tensorMap.get(dataId);
609+
resource = srcTensorData.resource;
607610
} else {
608611
throw new Error('There is no data on GPU or CPU.');
609612
}

tfjs-backend-webgpu/src/backend_webgpu_test.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,18 @@ describeWebGPU('backend webgpu', () => {
200200
await c3.data();
201201
tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag);
202202
});
203+
204+
it('dataToGPU uploads to GPU if the tensor is on CPU', async () => {
205+
const webGPUBackend = (tf.backend() as WebGPUBackend);
206+
const data = [1,2,3,4,5];
207+
const tensor = tf.tensor1d(data);
208+
const res = tensor.dataToGPU();
209+
expect(res.buffer).toBeDefined();
210+
const resData = await webGPUBackend.getBufferData(res.buffer);
211+
const values = tf.util.convertBackendValuesAndArrayBuffer(
212+
resData, res.tensorRef.dtype);
213+
expectArraysEqual(values, data);
214+
});
203215
});
204216

205217
describeWebGPU('backendWebGPU', () => {

tfjs-backend-webgpu/src/setup_test.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ const TEST_FILTERS: TestFilter[] = [
3333
'gradient', // gradient function not found.
3434
]
3535
},
36+
{
37+
startsWith: 'pow',
38+
excludes: [
39+
'int32' // MacOS precision issue
40+
],
41+
},
3642
{
3743
startsWith: 'exp ',
3844
excludes: [
@@ -62,6 +68,13 @@ const TEST_FILTERS: TestFilter[] = [
6268
excludes: [
6369
'gradients', // Failing on MacOS
6470
'gradient with clones', // Failing on MacOS
71+
'propagates NaNs', // Failing on MacOS
72+
],
73+
},
74+
{
75+
startsWith: 'sin ',
76+
excludes: [
77+
'propagates NaNs', // Failing on MacOS
6578
],
6679
},
6780
{

tfjs-converter/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ jax_conversion.convert_jax(
319319
```
320320

321321
See
322-
[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion)
322+
[here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion)
323323
for more details on the exact syntax for this argument.
324324

325325
When converting JAX models, you can also pass any [options that

tfjs-layers/BUILD.bazel

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ tfjs_web_test(
5959
],
6060
headless = False,
6161
seed = "12345",
62+
local_browser = select({
63+
"@bazel_tools//src/conditions:linux_x86_64": "chrome_webgpu_linux",
64+
"@bazel_tools//src/conditions:windows": "chrome_webgpu",
65+
"//conditions:default": "chrome_webgpu",
66+
}),
6267
static_files = [
6368
# Listed here so sourcemaps are served
6469
"//tfjs-layers/src:tfjs-layers_test_bundle",
@@ -79,6 +84,11 @@ tfjs_web_test(
7984
],
8085
headless = False,
8186
seed = "12345",
87+
local_browser = select({
88+
"@bazel_tools//src/conditions:linux_x86_64": "chrome_webgpu_linux",
89+
"@bazel_tools//src/conditions:windows": "chrome_webgpu",
90+
"//conditions:default": "chrome_webgpu",
91+
}),
8292
static_files = [
8393
# Listed here so sourcemaps are served
8494
"//tfjs-layers/src:tfjs-layers_test_bundle",

0 commit comments

Comments
 (0)