|
3 | 3 |
|
4 | 4 | from sqlalchemy.orm import Query, sessionmaker |
5 | 5 |
|
6 | | -from jsonapi_query.database.sqlalchemy import QueryMixin |
| 6 | +from jsonapi_query.database.sqlalchemy import group_and_remove, QueryMixin |
7 | 7 | from tests.sqlalchemy import ( |
8 | 8 | BaseSQLAlchemyTestCase, Category, Person, Product, School, Student) |
9 | 9 |
|
@@ -287,14 +287,47 @@ def test_include_ambiguous_join_conditions(self): |
287 | 287 |
|
288 | 288 | def test_include_does_not_restrict_primary_output(self): |
289 | 289 | """Test including a relationship does not restrict primary output.""" |
290 | | - p = Product(name='Tst') |
| 290 | + a = Category(name='Category A') |
| 291 | + self.session.add(a) |
| 292 | + b = Category(name='Category B', category_id=1) |
| 293 | + self.session.add(b) |
| 294 | + p = Product(primary_category_id=1, secondary_category_id=2, name='Tst') |
| 295 | + self.session.add(p) |
| 296 | + |
| 297 | + p = Product(name='Tst 2') |
291 | 298 | self.session.add(p) |
292 | 299 |
|
293 | 300 | models = self.session.query(Product).include( |
294 | 301 | [Product.primary_category]).all() |
295 | | - self.assertTrue(len(models) == 1) |
| 302 | + self.assertTrue(len(models) == 2) |
296 | 303 |
|
297 | 304 | def test_include_no_mappers(self): |
298 | 305 | """Test including an empty set of relationships.""" |
299 | 306 | models = self.session.query(Person).include([]).first() |
300 | 307 | self.assertTrue(isinstance(models, Person)) |
| 308 | + |
| 309 | + |
| 310 | +class UtilitySQLAlchemyTestCase(BaseDatabaseSQLAlchemyTests): |
| 311 | + """Test handling a query's output.""" |
| 312 | + |
| 313 | + def test_group_and_remove(self): |
| 314 | + """Test group and remove utility function.""" |
| 315 | + a = Category(name='Category A') |
| 316 | + self.session.add(a) |
| 317 | + b = Category(name='Category B', category_id=1) |
| 318 | + self.session.add(b) |
| 319 | + p = Product(primary_category_id=1, secondary_category_id=2, name='Tst') |
| 320 | + self.session.add(p) |
| 321 | + |
| 322 | + p = Product(name='Tst 2') |
| 323 | + self.session.add(p) |
| 324 | + |
| 325 | + # Returns two products and one category. |
| 326 | + items = self.session.query(Product).include( |
| 327 | + [Product.primary_category]).all() |
| 328 | + self.assertTrue(len(items) == 2) |
| 329 | + |
| 330 | + output = group_and_remove(items, [Product, Category]) |
| 331 | + self.assertTrue(len(output) == 2) |
| 332 | + self.assertTrue(len(output[0]) == 2) |
| 333 | + self.assertTrue(len(output[1]) == 1) |
0 commit comments