Skip to content

Commit de3e6d7

Browse files
committed
Guard out-of-range shift counts in integer opcodes
1 parent 96655c0 commit de3e6d7

2 files changed

Lines changed: 15 additions & 4 deletions

File tree

numexpr/interp_body.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,8 @@
266266
case OP_POW_III: VEC_ARG2(i_dest = (i2 < 0) ? (1 / i1) : (int)pow((double)i1, i2));
267267
case OP_MOD_III: VEC_ARG2(i_dest = i2 == 0 ? 0 :((i1 % i2) + i2) % i2);
268268
case OP_FLOORDIV_III: VEC_ARG2(i_dest = i2 ? (i1 / i2) - ((i1 % i2 != 0) && (i1 < 0 != i2 < 0)) : 0);
269-
case OP_LSHIFT_III: VEC_ARG2(i_dest = i1 << i2);
270-
case OP_RSHIFT_III: VEC_ARG2(i_dest = i1 >> i2);
269+
case OP_LSHIFT_III: VEC_ARG2(i_dest = (unsigned int)i2 < 32 ? i1 << i2 : 0);
270+
case OP_RSHIFT_III: VEC_ARG2(i_dest = i1 >> ((unsigned int)i2 < 32 ? i2 : 31));
271271

272272
case OP_WHERE_IBII: VEC_ARG3(i_dest = b1 ? i2 : i3);
273273
//Bitwise ops
@@ -292,8 +292,8 @@
292292
#endif
293293
case OP_MOD_LLL: VEC_ARG2(l_dest = l2 == 0 ? 0 :((l1 % l2) + l2) % l2);
294294
case OP_FLOORDIV_LLL: VEC_ARG2(l_dest = l2 ? (l1 / l2) - ((l1 % l2 != 0) && (l1 < 0 != l2 < 0)): 0);
295-
case OP_LSHIFT_LLL: VEC_ARG2(l_dest = l1 << l2);
296-
case OP_RSHIFT_LLL: VEC_ARG2(l_dest = l1 >> l2);
295+
case OP_LSHIFT_LLL: VEC_ARG2(l_dest = (unsigned long long)l2 < 64 ? l1 << l2 : 0);
296+
case OP_RSHIFT_LLL: VEC_ARG2(l_dest = l1 >> ((unsigned long long)l2 < 64 ? l2 : 63));
297297

298298
case OP_WHERE_LBLL: VEC_ARG3(l_dest = b1 ? l2 : l3);
299299
//Bitwise ops

numexpr/tests/test_numexpr.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,17 @@ def test_right_shift(self):
457457
x = arange(10, dtype='i4')
458458
assert_array_equal(evaluate("x>>2"), x >> 2)
459459

460+
def test_shift_out_of_range(self):
461+
# Shift counts that are negative or >= the operand width are
462+
# undefined behavior in C. Match NumPy, which treats them as a
463+
# full shift (0 for <<, sign fill for >>).
464+
for dtype in ('i4', 'i8'):
465+
x = array([5, -5, 0], dtype=dtype)
466+
for count in (-1, 64, 200):
467+
y = array([count] * len(x), dtype=dtype)
468+
assert_array_equal(evaluate("x << y"), x << y)
469+
assert_array_equal(evaluate("x >> y"), x >> y)
470+
460471
# PyTables uses __nonzero__ among ExpressionNode objects internally
461472
# so this should be commented out for the moment. See #24.
462473
def test_boolean_operator(self):

0 commit comments

Comments
 (0)