@@ -4,16 +4,37 @@ import { Map, Set } from 'immutable'
44import type { DataType , Network , TaskProvider } from "@epfml/discojs" ;
55import { defaultTasks } from '@epfml/discojs'
66
7- interface BenchmarkArguments {
7+ type AggregationStrategy = "mean" | "byzantine" | "secure" ;
8+
9+ function parseAggregator ( raw : string ) : AggregationStrategy {
10+ if ( raw === "mean" || raw == "byzantine" || raw == "secure" )
11+ return raw ;
12+ else
13+ throw new Error ( `Aggregator ${ raw } is not supported.` ) ;
14+ }
15+
16+ export interface BenchmarkArguments {
817 provider : TaskProvider < DataType , Network > ;
18+ testID : string
919 numberOfUsers : number
1020 epochs : number
1121 roundDuration : number
1222 batchSize : number
1323 validationSplit : number
24+
25+ // DP
1426 epsilon ?: number
1527 delta ?: number
1628 dpDefaultClippingRadius ?: number
29+ // Aggregator
30+ aggregator : AggregationStrategy
31+ // Byzantine aggregator
32+ clippingRadius ?: number
33+ maxIterations ?: number
34+ beta ?: number
35+ // Secure aggregator
36+ maxShareValue ?: number
37+
1738 save : boolean
1839 host : URL
1940}
@@ -27,15 +48,13 @@ const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs'
2748
2849const unsafeArgs = parse < BenchmarkUnsafeArguments > (
2950 {
51+ testID : { type : String , alias : 'i' , description : 'ID of the testcase' } ,
3052 task : { type : String , alias : 't' , description : 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid' , defaultValue : 'tinder_dog' } ,
3153 numberOfUsers : { type : Number , alias : 'u' , description : 'Number of users' , defaultValue : 2 } ,
3254 epochs : { type : Number , alias : 'e' , description : 'Number of epochs' , defaultValue : 10 } ,
3355 roundDuration : { type : Number , alias : 'r' , description : 'Round duration (in epochs)' , defaultValue : 2 } ,
3456 batchSize : { type : Number , alias : 'b' , description : 'Training batch size' , defaultValue : 10 } ,
3557 validationSplit : { type : Number , alias : 'v' , description : 'Validation dataset ratio' , defaultValue : 0.2 } ,
36- epsilon : { type : Number , alias : 'n' , description : 'Privacy budget' , optional : true , defaultValue : undefined } ,
37- delta : { type : Number , alias : 'd' , description : 'Probability of failure, slack parameter' , optional : true , defaultValue : undefined } ,
38- dpDefaultClippingRadius : { type : Number , alias : 'f' , description : 'Default clipping radius for DP' , optional : true , defaultValue : undefined } ,
3958 save : { type : Boolean , alias : 's' , description : 'Save logs of benchmark' , defaultValue : false } ,
4059 host : {
4160 type : ( raw : string ) => new URL ( raw ) ,
@@ -44,6 +63,22 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
4463 defaultValue : new URL ( "http://localhost:8080" ) ,
4564 } ,
4665
66+ // Aggregator
67+ aggregator : { type : parseAggregator , description : 'Type of weight aggregator' , defaultValue : 'mean' } ,
68+
69+ // Byzantine aggregator
70+ clippingRadius : { type : Number , description : "Clipping radius for centered clipping" , optional : true } ,
71+ maxIterations : { type : Number , description : "Maximum centered clipping iterations" , optional : true } ,
72+ beta : { type : Number , description : "Momentum coefficient to smooth the aggregation over multiple rounds" , optional : true } ,
73+
74+ // Secure aggregator
75+ maxShareValue : { type : Number , description : "Maximum absolute value over all the weights" , optional : true } ,
76+
77+ // Differential Privacy
78+ epsilon : { type : Number , description : 'Privacy budget' , optional : true , defaultValue : undefined } ,
79+ delta : { type : Number , description : 'Probability of failure, slack parameter' , optional : true , defaultValue : undefined } ,
80+ dpDefaultClippingRadius : { type : Number , description : 'Default clipping radius for DP' , optional : true , defaultValue : undefined } ,
81+
4782 help : { type : Boolean , optional : true , alias : 'h' , description : 'Prints this usage guide' }
4883 } ,
4984 {
@@ -88,6 +123,44 @@ export const args: BenchmarkArguments = {
88123 task . trainingInformation . epochs = unsafeArgs . epochs ;
89124 task . trainingInformation . validationSplit = unsafeArgs . validationSplit ;
90125
126+ const { aggregator, clippingRadius, maxIterations, beta, maxShareValue} = unsafeArgs ;
127+
128+ // For aggregators
129+ if ( aggregator !== undefined )
130+ task . trainingInformation . aggregationStrategy = aggregator ;
131+
132+ // For byzantine aggregator
133+ if (
134+ clippingRadius !== undefined &&
135+ maxIterations !== undefined &&
136+ beta !== undefined
137+ ) {
138+ if ( task . trainingInformation . scheme === "local" )
139+ throw new Error ( "Byzantine aggregator is not supported for local training" ) ;
140+ if ( task . trainingInformation . aggregationStrategy !== "byzantine" )
141+ throw new Error ( "Byzantine parameters can be set only when aggregationStrategy is byzantine" ) ;
142+
143+ task . trainingInformation . privacy = {
144+ ...task . trainingInformation . privacy ,
145+ byzantineFaultTolerance : {
146+ clippingRadius,
147+ maxIterations,
148+ beta,
149+ } ,
150+ } ;
151+ }
152+
153+ // For secure aggregator
154+ if ( maxShareValue !== undefined ) {
155+
156+ if ( task . trainingInformation . scheme !== "decentralized" )
157+ throw new Error ( "Secure aggation is only supported for decentralized laerning" )
158+ if ( task . trainingInformation . aggregationStrategy !== "secure" )
159+ throw new Error ( "maxShareValue can be set when aggregationStrategy is secure" ) ;
160+
161+ task . trainingInformation . maxShareValue = maxShareValue ;
162+ }
163+
91164 // For DP
92165 const { dpDefaultClippingRadius, epsilon, delta} = unsafeArgs ;
93166
@@ -102,7 +175,7 @@ export const args: BenchmarkArguments = {
102175 const defaultRadius = dpDefaultClippingRadius ? dpDefaultClippingRadius : 1 ;
103176
104177 // for the case where privacy parameters are not defined in the default tasks
105- task . trainingInformation . privacy ??= { }
178+ task . trainingInformation . privacy ??= { } ;
106179 task . trainingInformation . privacy . differentialPrivacy = {
107180 clippingRadius : defaultRadius ,
108181 epsilon : epsilon ,
0 commit comments