Skip to content

Commit 133f33b

Browse files
committed
Fix 3 issues from code review: variable scoping, squeeze edge case, Gloo P2P
1. Initialize `_use_cached = False` at the top of `_generate_graph()` to prevent a potential `UnboundLocalError` if the code is later refactored (escn_md.py).
2. Replace `squeeze()` with `reshape(-1)` in `_balance_assignments` to handle the edge case where `src_atoms` has exactly one element — `squeeze()` would produce a 0-d tensor that cannot be sliced (graph_parallel.py).
3. Skip zero-length P2P ops in the Gloo fallback paths of `_safe_all_to_all` and `start_all_to_all_collect` to avoid potential hangs on some PyTorch versions (graph_parallel.py).
1 parent c1d92ee commit 133f33b

2 files changed

Lines changed: 14 additions & 8 deletions

File tree

src/fairchem/core/models/uma/escn_md.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ def _generate_graph(self, data_dict):
         data_dict["gp_node_offset"] = 0
         node_partition = None
         rank_assignments = None
+        _use_cached = False
         if gp_utils.initialized():
             # create the partitions
             atomic_numbers_full = data_dict["atomic_numbers_full"]

src/fairchem/core/models/uma/graph_parallel.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ def _safe_all_to_all(
     for r in range(world_size):
         if r == rank:
             # Local copy
-            output_list[r].copy_(input_list[r])
-        else:
-            # Send our data to rank r
+            if input_list[r].numel() > 0:
+                output_list[r].copy_(input_list[r])
+        elif input_list[r].numel() > 0 or output_list[r].numel() > 0:
+            # Skip zero-length P2P ops to avoid potential hangs
             ops.append(dist.P2POp(dist.isend, input_list[r], r, group=group))
-            # Receive data from rank r
             ops.append(dist.P2POp(dist.irecv, output_list[r], r, group=group))
     if ops:
         reqs = dist.batch_isend_irecv(ops)
@@ -248,11 +248,14 @@ def _balance_assignments(

     if pos is not None and centroids is not None:
         # Move atoms closest to the destination centroid
-        # to preserve spatial locality
+        # to preserve spatial locality.
+        # Use reshape(-1) instead of squeeze() to handle the case
+        # where src_atoms has exactly 1 element (squeeze() would
+        # produce a 0-d tensor that cannot be sliced).
         dists_to_dst = torch.cdist(
             pos[src_atoms].unsqueeze(0),
             centroids[dst_rank].unsqueeze(0).unsqueeze(0),
-        ).squeeze()
+        ).reshape(-1)
         _, closest_order = dists_to_dst.sort()
         atoms_to_move = src_atoms[closest_order[:n_move]]
     else:
@@ -782,8 +785,10 @@ def start_all_to_all_collect(
     ops = []
     for r in range(world_size):
         if r == rank:
-            recv_list[r].copy_(send_list[r])
-        else:
+            if send_list[r].numel() > 0:
+                recv_list[r].copy_(send_list[r])
+        elif send_list[r].numel() > 0 or recv_list[r].numel() > 0:
+            # Skip zero-length P2P ops to avoid potential hangs
             ops.append(dist.P2POp(dist.isend, send_list[r], r, group=gp_group))
             ops.append(dist.P2POp(dist.irecv, recv_list[r], r, group=gp_group))
     if ops:

0 commit comments

Comments
 (0)