Skip to content

Commit de71d06

Browse files
authored
Merge pull request #1080 from epfml/stress-testing
Add automated test runner and visualization script for stress testing
2 parents c2c2111 + 2b6cd20 commit de71d06

21 files changed

Lines changed: 759 additions & 114 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ dist/
1414
.idea/
1515
.vscode/
1616
*.DS_Store
17+
18+
# python venv
19+
.venv/

cli/src/args.ts

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,37 @@ import { Map, Set } from 'immutable'
44
import type { DataType, Network, TaskProvider } from "@epfml/discojs";
55
import { defaultTasks } from '@epfml/discojs'
66

7-
interface BenchmarkArguments {
7+
type AggregationStrategy = "mean" | "byzantine" | "secure";
8+
9+
function parseAggregator(raw: string): AggregationStrategy{
10+
if (raw === "mean" || raw == "byzantine" || raw == "secure")
11+
return raw;
12+
else
13+
throw new Error(`Aggregator ${raw} is not supported.`);
14+
}
15+
16+
export interface BenchmarkArguments {
817
provider: TaskProvider<DataType, Network>;
18+
testID: string
919
numberOfUsers: number
1020
epochs: number
1121
roundDuration: number
1222
batchSize: number
1323
validationSplit: number
24+
25+
// DP
1426
epsilon?: number
1527
delta?: number
1628
dpDefaultClippingRadius?: number
29+
// Aggregator
30+
aggregator: AggregationStrategy
31+
// Byzantine aggregator
32+
clippingRadius?: number
33+
maxIterations?: number
34+
beta?: number
35+
// Secure aggregator
36+
maxShareValue?: number
37+
1738
save: boolean
1839
host: URL
1940
}
@@ -27,15 +48,13 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'
2748

2849
const unsafeArgs = parse<BenchmarkUnsafeArguments>(
2950
{
51+
testID: { type: String, alias: 'i', description: 'ID of the testcase' },
3052
task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' },
3153
numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 },
3254
epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 },
3355
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
3456
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
3557
validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 },
36-
epsilon: { type: Number, alias: 'n', description: 'Privacy budget', optional: true, defaultValue: undefined},
37-
delta: { type: Number, alias: 'd', description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined},
38-
dpDefaultClippingRadius: {type: Number, alias: 'f', description: 'Default clipping radius for DP', optional: true, defaultValue: undefined},
3958
save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
4059
host: {
4160
type: (raw: string) => new URL(raw),
@@ -44,6 +63,22 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
4463
defaultValue: new URL("http://localhost:8080"),
4564
},
4665

66+
// Aggregator
67+
aggregator: { type: parseAggregator, description: 'Type of weight aggregator', defaultValue: 'mean' },
68+
69+
// Byzantine aggregator
70+
clippingRadius: { type: Number, description: "Clipping radius for centered clipping", optional: true },
71+
maxIterations: { type: Number, description: "Maximum centered clipping iterations", optional: true },
72+
beta: { type: Number, description: "Momentum coefficient to smooth the aggregation over multiple rounds", optional: true },
73+
74+
// Secure aggregator
75+
maxShareValue: { type: Number, description: "Maximum absolute value over all the weights", optional: true },
76+
77+
// Differential Privacy
78+
epsilon: { type: Number, description: 'Privacy budget', optional: true, defaultValue: undefined},
79+
delta: { type: Number, description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined},
80+
dpDefaultClippingRadius: {type: Number, description: 'Default clipping radius for DP', optional: true, defaultValue: undefined},
81+
4782
help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' }
4883
},
4984
{
@@ -88,6 +123,44 @@ export const args: BenchmarkArguments = {
88123
task.trainingInformation.epochs = unsafeArgs.epochs;
89124
task.trainingInformation.validationSplit = unsafeArgs.validationSplit;
90125

126+
const {aggregator, clippingRadius, maxIterations, beta, maxShareValue} = unsafeArgs;
127+
128+
// For aggregators
129+
if (aggregator !== undefined)
130+
task.trainingInformation.aggregationStrategy = aggregator;
131+
132+
// For byzantine aggregator
133+
if (
134+
clippingRadius !== undefined &&
135+
maxIterations !== undefined &&
136+
beta !== undefined
137+
){
138+
if (task.trainingInformation.scheme === "local")
139+
throw new Error("Byzantine aggregator is not supported for local training");
140+
if (task.trainingInformation.aggregationStrategy !== "byzantine")
141+
throw new Error("Byzantine parameters can be set only when aggregationStrategy is byzantine");
142+
143+
task.trainingInformation.privacy = {
144+
...task.trainingInformation.privacy,
145+
byzantineFaultTolerance: {
146+
clippingRadius,
147+
maxIterations,
148+
beta,
149+
},
150+
};
151+
}
152+
153+
// For secure aggregator
154+
if (maxShareValue !== undefined){
155+
156+
if (task.trainingInformation.scheme !== "decentralized")
157+
throw new Error("Secure aggregation is only supported for decentralized learning")
158+
if (task.trainingInformation.aggregationStrategy !== "secure")
159+
throw new Error("maxShareValue can only be set when aggregationStrategy is secure");
160+
161+
task.trainingInformation.maxShareValue = maxShareValue;
162+
}
163+
91164
// For DP
92165
const {dpDefaultClippingRadius, epsilon, delta} = unsafeArgs;
93166

@@ -102,7 +175,7 @@ export const args: BenchmarkArguments = {
102175
const defaultRadius = dpDefaultClippingRadius ? dpDefaultClippingRadius : 1;
103176

104177
// for the case where privacy parameters are not defined in the default tasks
105-
task.trainingInformation.privacy ??= {}
178+
task.trainingInformation.privacy ??= {};
106179
task.trainingInformation.privacy.differentialPrivacy = {
107180
clippingRadius: defaultRadius,
108181
epsilon: epsilon,

cli/src/cli.ts

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ import "@tensorflow/tfjs-node"
33

44
import { List, Range } from 'immutable'
55
import fs from 'node:fs/promises'
6+
import { createWriteStream } from "node:fs";
7+
import path from "node:path";
68

79
import type {
810
Dataset,
911
DataFormat,
1012
DataType,
11-
RoundLogs,
13+
SummaryLogs,
1214
Task,
1315
TaskProvider,
1416
Network,
@@ -17,49 +19,103 @@ import { Disco, aggregator as aggregators, client as clients } from '@epfml/disc
1719

1820
import { getTaskData } from './data.js'
1921
import { args } from './args.js'
22+
import { makeUserLogFile } from "./user_log.js";
23+
import type { UserLogFile } from "./user_log.js";
2024

21-
// Array.fromAsync not yet widely used (2024)
22-
async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
23-
const ret: T[] = [];
24-
for await (const e of iter) ret.push(e);
25-
return ret;
26-
}
2725

2826
async function runUser<D extends DataType, N extends Network>(
2927
task: Task<D, N>,
3028
url: URL,
3129
data: Dataset<DataFormat.Raw[D]>,
32-
): Promise<List<RoundLogs>> {
30+
userIndex: number,
31+
numberOfUsers: number,
32+
): Promise<List<SummaryLogs>> {
3333
// cast as typescript isn't good with generics
3434
const trainingScheme = task.trainingInformation.scheme as N
3535
const aggregator = aggregators.getAggregator(task)
3636
const client = clients.getClient(trainingScheme, url, task, aggregator)
3737
const disco = new Disco(task, client, { scheme: trainingScheme });
3838

39-
const logs = List(await arrayFromAsync(disco.trainByRound(data)));
40-
await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish
41-
await disco.close();
42-
return logs;
39+
const dir = path.join(".", `${args.testID}`);
40+
await fs.mkdir(dir, { recursive: true });
41+
const streamPath = path.join(dir, `client${userIndex}_local_log.jsonl`);
42+
43+
const finalLog: SummaryLogs[] = [];
44+
// create a write stream that saves learning logs during the train
45+
let jsonStream: ReturnType<typeof createWriteStream> | null = null;
46+
47+
if (args.save){
48+
jsonStream = createWriteStream(streamPath, {flags: "w"});
49+
}
50+
51+
try{
52+
for await (const log of disco.trainSummary(data)){
53+
finalLog.push(log);
54+
55+
if (jsonStream){
56+
jsonStream.write(JSON.stringify(log) + "\n");
57+
}
58+
}
59+
60+
await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish
61+
62+
// saving the entire per-user logs
63+
if (args.save) {
64+
const finalPath = path.join(dir, `client${userIndex}_local_log.json`);
65+
66+
const userLog: UserLogFile = makeUserLogFile(task, numberOfUsers, userIndex, client.ownId, finalLog);
67+
68+
await fs.writeFile(finalPath, JSON.stringify(userLog, null, 2));
69+
}
70+
71+
return List(finalLog);
72+
}catch(err){
73+
console.error(`Run user failed for client ${userIndex}: `, err);
74+
throw err;
75+
}finally{
76+
try{
77+
if (jsonStream){
78+
jsonStream.end();
79+
80+
await new Promise<void>((resolve, reject) => {
81+
jsonStream.once("finish", resolve);
82+
jsonStream.once("error", reject);
83+
});
84+
}
85+
}catch(err){
86+
console.error(`failed to close log stream for client ${userIndex}: `, err);
87+
}
88+
89+
try{
90+
await disco.close();
91+
}catch(err){
92+
console.error(`failed to close disco for client ${userIndex}: `, err);
93+
}
94+
}
4395
}
4496

4597
async function main<D extends DataType, N extends Network>(
4698
provider: TaskProvider<D, N>,
4799
numberOfUsers: number,
48100
): Promise<void> {
49101
const task = await provider.getTask();
102+
console.log(`Test ID: ${args.testID}`)
50103
console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`)
51104
console.log({ args })
52105

53106
const dataSplits = await Promise.all(
54107
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers))
55108
)
56109
const logs = await Promise.all(
57-
dataSplits.map(async data => await runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>))
110+
dataSplits.map((data, i) => runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>, i, numberOfUsers))
58111
)
59112

60113
if (args.save) {
61-
const fileName = `${task.id}_${numberOfUsers}users.csv`;
62-
await fs.writeFile(fileName, JSON.stringify(logs, null, 2));
114+
const dir = path.join(".", `${args.testID}`, `${task.id}`);
115+
await fs.mkdir(dir, { recursive: true });
116+
117+
const filePath = path.join(dir, `${task.id}_${numberOfUsers}users.json`);
118+
await fs.writeFile(filePath, JSON.stringify(logs, null, 2));
63119
}
64120
}
65121

0 commit comments

Comments
 (0)