Skip to content

Commit ee6e6d5

Browse files
committed
check
1 parent 166ff39 commit ee6e6d5

4 files changed

Lines changed: 20 additions & 0 deletions

File tree

paddle/phi/kernels/cpu/gather_tree_kernel.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ void GatherTreeKernel(const Context &dev_ctx,
4949
out_data[idx] = ids_data[idx];
5050
auto parent = parents_data[idx];
5151
for (int step = max_length - 2; step >= 0; step--) {
52+
PADDLE_ENFORCE_LT(
53+
parent,
54+
beam_size,
55+
paddle::platform::errors::InvalidArgument(
56+
"The parents must be less than beam size, but recieved"
57+
"parents %d is greater than or equal to beam size %d. ",
58+
parent,
59+
beam_size));
60+
5261
idx = step * batch_size * beam_size + batch * beam_size;
5362
out_data[idx + beam] = ids_data[idx + parent];
5463
parent = parents_data[idx + parent];

paddle/phi/kernels/gather_tree_kernel.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include "paddle/fluid/platform/enforce.h"
1718
#include "paddle/phi/core/dense_tensor.h"
1819

1920
namespace phi {

paddle/phi/kernels/gpu/gather_tree_kernel.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ __global__ void GatherTree(const T *ids_data,
3535
out_data[idx] = ids_data[idx];
3636
auto parent = parents_data[idx];
3737
for (int step = max_length - 2; step >= 0; step--) {
38+
assert(parent < beam_size &&
39+
"The parents of gather_tree op must be less than beam size. ");
40+
3841
idx = step * batch_size * beam_size + batch * beam_size;
3942
out_data[idx + beam] = ids_data[idx + parent];
4043
parent = parents_data[idx + parent];

python/paddle/nn/functional/extension.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,13 @@ def gather_tree(ids, parents):
303303
# [[[2, 2], [1, 6]], [[3, 3], [6, 1]], [[0, 1], [9, 0]]]
304304
305305
"""
306+
if ids.ndim != 3:
307+
raise ValueError(
308+
"The input ids must be a 3D tensor with shape [length, batch_size, beam_size]"
309+
)
310+
if ids.ndim != parents.ndim:
311+
raise ValueError("The ids's shape must be the same as parents' shape. ")
312+
306313
if in_dygraph_mode():
307314
return _C_ops.gather_tree(ids, parents)
308315
else:

0 commit comments

Comments
 (0)