@@ -51,7 +51,6 @@ def __init__(
5151 if self .world_size == 1 :
5252 self .available = False
5353 self .disabled = True
54- self .stream = None
5554 return
5655 try :
5756 self .nccl = NCCLLibrary (library_path )
@@ -60,7 +59,6 @@ def __init__(
6059 # e.g. in a non-GPU environment
6160 self .available = False
6261 self .disabled = True
63- self .stream = None
6462 return
6563
6664 self .available = True
@@ -98,12 +96,12 @@ def __init__(
9896 with torch .cuda .device (device ):
9997 self .comm : ncclComm_t = self .nccl .ncclCommInitRank (
10098 self .world_size , self .unique_id , self .rank )
101- self .stream = torch .cuda .Stream ()
10299
100+ stream = torch .cuda .current_stream ()
103101 # A small all_reduce for warmup.
104102 data = torch .zeros (1 , device = device )
105103 self .all_reduce (data )
106- self . stream .synchronize ()
104+ stream .synchronize ()
107105 del data
108106
109107 def all_reduce (self ,
@@ -122,7 +120,7 @@ def all_reduce(self,
122120 out_tensor = torch .empty_like (in_tensor )
123121
124122 if stream is None :
125- stream = self . stream
123+ stream = torch . cuda . current_stream ()
126124 self .nccl .ncclAllReduce (buffer_type (in_tensor .data_ptr ()),
127125 buffer_type (out_tensor .data_ptr ()),
128126 in_tensor .numel (),
@@ -144,7 +142,7 @@ def all_gather(self,
144142 f"this nccl communicator is created to work on { self .device } , "
145143 f"but the input tensor is on { input_tensor .device } " )
146144 if stream is None :
147- stream = self . stream
145+ stream = torch . cuda . current_stream ()
148146 self .nccl .ncclAllGather (
149147 buffer_type (input_tensor .data_ptr ()),
150148 buffer_type (output_tensor .data_ptr ()), input_tensor .numel (),
@@ -165,7 +163,7 @@ def reduce_scatter(self,
165163 f"this nccl communicator is created to work on { self .device } , "
166164 f"but the input tensor is on { input_tensor .device } " )
167165 if stream is None :
168- stream = self . stream
166+ stream = torch . cuda . current_stream ()
169167 self .nccl .ncclReduceScatter (
170168 buffer_type (input_tensor .data_ptr ()),
171169 buffer_type (output_tensor .data_ptr ()), output_tensor .numel (),
@@ -180,7 +178,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
180178 f"this nccl communicator is created to work on { self .device } , "
181179 f"but the input tensor is on { tensor .device } " )
182180 if stream is None :
183- stream = self . stream
181+ stream = torch . cuda . current_stream ()
184182 self .nccl .ncclSend (buffer_type (tensor .data_ptr ()), tensor .numel (),
185183 ncclDataTypeEnum .from_torch (tensor .dtype ), dst ,
186184 self .comm , cudaStream_t (stream .cuda_stream ))
@@ -192,7 +190,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
192190 f"this nccl communicator is created to work on { self .device } , "
193191 f"but the input tensor is on { tensor .device } " )
194192 if stream is None :
195- stream = self . stream
193+ stream = torch . cuda . current_stream ()
196194 self .nccl .ncclRecv (buffer_type (tensor .data_ptr ()), tensor .numel (),
197195 ncclDataTypeEnum .from_torch (tensor .dtype ), src ,
198196 self .comm , cudaStream_t (stream .cuda_stream ))
@@ -204,7 +202,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
204202 f"this nccl communicator is created to work on { self .device } , "
205203 f"but the input tensor is on { tensor .device } " )
206204 if stream is None :
207- stream = self . stream
205+ stream = torch . cuda . current_stream ()
208206 if src == self .rank :
209207 sendbuff = buffer_type (tensor .data_ptr ())
210208 # NCCL requires the sender also to have a receive buffer
@@ -217,25 +215,17 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
217215 self .comm , cudaStream_t (stream .cuda_stream ))
218216
@contextmanager
def change_state(self, enable: Optional[bool] = None):
    """
    A context manager to change the state of the communicator.

    While the ``with`` body runs, ``self.disabled`` is set to
    ``not enable``. On exit the previous value of ``self.disabled`` is
    put back, including when the body raises an exception.

    Args:
        enable: Whether the communicator should be enabled inside the
            context. When ``None``, ``self.available`` is used instead.
    """
    if enable is None:
        # guess a default value when not specified
        enable = self.available

    previous_disabled = self.disabled
    self.disabled = not enable
    try:
        yield
    finally:
        # Restore in ``finally`` so an exception raised inside the
        # ``with`` body cannot leave the temporary state in place.
        self.disabled = previous_disabled
0 commit comments