Grouping and Aggregations With Java Streams

When we group elements from a list, we can aggregate the fields of the group elements to perform meaningful operations that help us analyze the data. Some examples are addition, averages, or max/min values. These aggregations of single fields can be easily done with Java Streams and Collectors. The documentation provides simple examples of how to do these types of calculations.

However, there are more sophisticated aggregations, such as weighted averages and geometric means. Additionally, there might be the need to do simultaneous aggregations of several fields. In this article, we are going to show a straightforward path to solve these kinds of problems using Java Streams. Using this framework allows us to process large amounts of data quickly and efficiently.

We’ll assume that the reader has a basic understanding of Java Streams and the utility Collectors class.

Problem Layout

Let’s consider a simple example to showcase the type of issues that we want to solve. We’ll make it very generic so we can easily generalize it. Let’s consider a list of TaxEntry entities defined by the following code:

public class TaxEntry {

    private String state;
    private String city;
    private int numEntries;
    private double price;
    //Constructors, getters, hashCode, equals etc
}

It is very simple to compute the total number of entries for a given city:

Map<String, Integer> totalNumEntriesByCity = 
              taxes.stream().collect(Collectors.groupingBy(TaxEntry::getCity, 
                                                           Collectors.summingInt(TaxEntry::getNumEntries)));

Collectors.groupingBy takes two parameters: a classifier function to do the grouping and a Collector that does the downstream aggregation for all the elements that belong to a given group. We use TaxEntry::getCity as the classifier function. For the downstream, we use Collectors::summingInt, which returns a Collector that sums the number of tax entries over the elements of each group.

Things are a little more complicated if we try to find compound groupings: for example, with the previous data, the total number of entries for a given state and city. There are several ways to do this, but a very straightforward approach is first to define:

record StateCityGroup(String state, String city) {}

Notice that we’re using a Java record, which is a concise way to define an immutable class. Furthermore, the Java compiler generates for us the field accessor methods and the hashCode, equals, and toString implementations. With this in hand, the solution now is simple:

Map<StateCityGroup, Integer> totalNumEntriesForStateCity = 
                    taxes.stream().collect(groupingBy(p -> new StateCityGroup(p.getState(), p.getCity()), 
                                                      Collectors.summingInt(TaxEntry::getNumEntries))
                                          );

For Collectors::groupingBy we set the classifier function using a lambda expression that creates a new StateCityGroup record that encapsulates each state-city. The downstream Collector is the same as before.
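To try the compound grouping end to end, here is a minimal, self-contained sketch. It makes a few assumptions that are not part of the article: TaxEntry is written as a record (so the accessors are state(), city(), and so on instead of getters), and the sample data and the class name CompoundGroupingDemo are invented for illustration.

```java
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.summingInt;

class CompoundGroupingDemo {

    // Hypothetical stand-in for the article's TaxEntry class; a record keeps the sketch short.
    record TaxEntry(String state, String city, int numEntries, double price) {}

    // Compound grouping key; records compare by value, so equal state/city pairs share a group.
    record StateCityGroup(String state, String city) {}

    static Map<StateCityGroup, Integer> totalNumEntriesForStateCity(List<TaxEntry> taxes) {
        return taxes.stream().collect(
                groupingBy(p -> new StateCityGroup(p.state(), p.city()),
                           summingInt(TaxEntry::numEntries)));
    }

    // Made-up sample data, for illustration only.
    static List<TaxEntry> sampleTaxes() {
        return List.of(
                new TaxEntry("New York", "NYC", 3, 20.0),
                new TaxEntry("New York", "NYC", 2, 10.0),
                new TaxEntry("New York", "Albany", 1, 15.0),
                new TaxEntry("Florida", "Orlando", 4, 13.0));
    }

    public static void main(String[] args) {
        Map<StateCityGroup, Integer> totals = totalNumEntriesForStateCity(sampleTaxes());
        // A freshly created key finds the group because of the generated equals/hashCode.
        System.out.println(totals.get(new StateCityGroup("New York", "NYC"))); // prints 5
    }
}
```

The lookup with a newly constructed StateCityGroup works only because the key type has value-based equals and hashCode, which is what the record gives us for free.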

Note: For the sake of conciseness, in the code samples, we are going to assume static imports for all the methods of the Collectors class, so we don’t have to show their class qualification.

Where things start to get more complicated is if we want to do several aggregations simultaneously. For example, find the sum of the number of entries and the average price for a given state and city. The library does not provide a simple solution to this problem.

To begin untangling this issue, we take a cue from the previous aggregation and define a record that encapsulates all the fields that need to be aggregated:

record TaxEntryAggregation (int totalNumEntries, double averagePrice ) {}

Now, how do we do the aggregation simultaneously for the two fields? There is always the possibility of streaming over each group twice to find each of the aggregations separately, as suggested in the following code:

Map<StateCityGroup, TaxEntryAggregation> aggregationByStateCity = taxes.stream().collect(
           groupingBy(p -> new StateCityGroup(p.getState(), p.getCity()),
                      collectingAndThen(Collectors.toList(), 
                                        list -> {int entries = list.stream().collect(
                                                                   summingInt(TaxEntry::getNumEntries));
                                                 double priceAverage = list.stream().collect(
                                                                   averagingDouble(TaxEntry::getPrice));
                                                 return new TaxEntryAggregation(entries, priceAverage);})));

The grouping is done as before, but for the downstream, we do the aggregation using Collectors::collectingAndThen (line 3). This function takes two parameters:

  • The downstream collector for the initial grouping, which gathers the elements of each group into a list (using Collectors::toList() in line 3)
  • The finisher function (lines 4–8), where we use a lambda expression to create two different streams from the previous list to do the aggregations and return them combined in a new TaxEntryAggregation record

Imagine that we wanted to do more field aggregations simultaneously. We would need to increase the number of streams over the downstream list accordingly. The code becomes inefficient, very repetitive, and less than desirable. We should look for better alternatives.

Also, the problems don’t end here, and in general, we’re constrained in the types of aggregations that we can do with the Collectors helper class. Its methods summing*, averaging*, and summarizing* provide support only for the int, long, and double primitive types. What do we do if we have more sophisticated types like BigInteger or BigDecimal?

To add insult to injury, the summarizing* methods only provide summary statistics for min, max, count, sum, and average. What if we want to perform more sophisticated calculations such as weighted averages or geometric means?

Some people will argue that we can always write custom collectors, but this requires knowing the collector interface and a good understanding of the stream collector flow. It’s more straightforward to use built-in collectors provided with the utility methods in the Collectors class. In the next section, we’ll show a couple of strategies on how to accomplish this.
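For comparison, the following is a rough sketch of what the custom collector route could look like for the sum of entries plus the average price. It is only one possible shape, not a reference implementation: the Accumulator class, the toTaxEntryAggregation() helper, the record-based TaxEntry, and the sample data are all hypothetical names introduced here. Collector.of is the standard factory method for building a collector from a supplier, an accumulator, a combiner, and a finisher.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collector;
import java.util.stream.Collectors;

class CustomCollectorSketch {

    // Hypothetical stand-ins for the article's types (records used for brevity).
    record TaxEntry(String state, String city, int numEntries, double price) {}
    record StateCityGroup(String state, String city) {}
    record TaxEntryAggregation(int totalNumEntries, double averagePrice) {}

    // Mutable container that the collector updates in place while it consumes elements.
    static final class Accumulator {
        private int totalNumEntries;
        private double priceSum;
        private long count;

        void add(TaxEntry entry) {
            totalNumEntries += entry.numEntries();
            priceSum += entry.price();
            count++;
        }

        // Combines two partial results; only needed when the stream runs in parallel.
        Accumulator merge(Accumulator other) {
            totalNumEntries += other.totalNumEntries;
            priceSum += other.priceSum;
            count += other.count;
            return this;
        }

        // Converts the running totals into the final immutable result.
        TaxEntryAggregation finish() {
            double average = count == 0 ? 0.0 : priceSum / count;
            return new TaxEntryAggregation(totalNumEntries, average);
        }
    }

    static Collector<TaxEntry, Accumulator, TaxEntryAggregation> toTaxEntryAggregation() {
        return Collector.of(Accumulator::new, Accumulator::add, Accumulator::merge, Accumulator::finish);
    }

    static Map<StateCityGroup, TaxEntryAggregation> aggregateByStateCity(List<TaxEntry> taxes) {
        return taxes.stream().collect(Collectors.groupingBy(
                p -> new StateCityGroup(p.state(), p.city()),
                toTaxEntryAggregation()));
    }

    // Made-up sample data, for illustration only.
    static List<TaxEntry> sampleTaxes() {
        return List.of(
                new TaxEntry("New York", "NYC", 3, 20.0),
                new TaxEntry("New York", "NYC", 2, 10.0),
                new TaxEntry("Florida", "Orlando", 4, 13.0));
    }

    public static void main(String[] args) {
        System.out.println(aggregateByStateCity(sampleTaxes()).get(new StateCityGroup("New York", "NYC")));
        // prints TaxEntryAggregation[totalNumEntries=5, averagePrice=15.0]
    }
}
```

The single pass is efficient, but every new aggregation needs its own accumulator class with four cooperating pieces, which is the overhead the paragraph above refers to.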

Complex Multiple Aggregations: A Resolution Path

Let’s consider a simple example that will highlight the challenges that we have mentioned in the previous section. Suppose that we have the following entity:

public class TaxEntry {
    private String state;
    private String city;
    private BigDecimal rate;
    private BigDecimal price;
    record StateCityGroup(String state, String city) {
    }
    //Constructors, getters, hashCode/equals etc
}

We start by asking how, for each distinct state-city pair, we can find the total count of entries and the total sum of the product of rate and price (∑(rate * price)). Notice that we are doing a multifield aggregation using BigDecimal.

As we did in the previous section, we define a class that encapsulates the aggregation:

record RatePriceAggregation(int count, BigDecimal ratePrice) {}

It might seem surprising at first, but a straightforward solution to groupings that are followed by simple aggregations is to use Collectors::toMap. Let’s see how we would do it:

Map<StateCityGroup, RatePriceAggregation> mapAggregation = taxes.stream().collect(
      toMap(p -> new StateCityGroup(p.getState(), p.getCity()), 
            p -> new RatePriceAggregation(1, p.getRate().multiply(p.getPrice())), 
            (u1,u2) -> new RatePriceAggregation( u1.count() + u2.count(), u1.ratePrice().add(u2.ratePrice()))
            ));

Collectors::toMap (line 2) takes three parameters, which we implement as follows:

  • The first parameter is a lambda expression to generate the keys of the map. This function creates StateCityGroup records as keys of the map, which groups the elements by state and city (line 2).
  • The second parameter produces the values of the map. In our case, we create a RatePriceAggregation initialized with a count of 1 and the product of rate and price (line 3).
  • Finally, the last parameter is a BinaryOperator to merge cases where multiple elements map to the same state-city key. We sum the counts and the rate-price products to do our aggregation (line 4).

Let’s demonstrate how this works by setting up some sample data:

List<TaxEntry> taxes = Arrays.asList(
                          new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.2), BigDecimal.valueOf(20.0)), 
                          new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.4), BigDecimal.valueOf(10.0)), 
                          new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.6), BigDecimal.valueOf(10.0)), 
                          new TaxEntry("Florida", "Orlando", BigDecimal.valueOf(0.3), BigDecimal.valueOf(13.0)));

Getting the results for New York from the previous code sample is trivial:

System.out.println("New York: " + mapAggregation.get(new StateCityGroup("New York", "NYC")));

This prints:

New York: RatePriceAggregation[count=3, ratePrice=14.00]
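For readers who want to run this, the pieces above can be assembled into one file roughly as follows. This is a sketch under a few assumptions: TaxEntry is written as a record, so rate() and price() take the place of the getters, and the class name ToMapAggregationDemo and the helper methods are invented here.

```java
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.toMap;

class ToMapAggregationDemo {

    // Hypothetical record version of the article's TaxEntry class.
    record TaxEntry(String state, String city, BigDecimal rate, BigDecimal price) {}
    record StateCityGroup(String state, String city) {}
    record RatePriceAggregation(int count, BigDecimal ratePrice) {}

    static Map<StateCityGroup, RatePriceAggregation> aggregate(List<TaxEntry> taxes) {
        return taxes.stream().collect(toMap(
                // key: the state-city pair
                p -> new StateCityGroup(p.state(), p.city()),
                // value for a single element: count 1 and rate * price
                p -> new RatePriceAggregation(1, p.rate().multiply(p.price())),
                // merge two values that landed on the same key
                (u1, u2) -> new RatePriceAggregation(u1.count() + u2.count(),
                                                     u1.ratePrice().add(u2.ratePrice()))));
    }

    // Same sample data as in the article.
    static List<TaxEntry> sampleTaxes() {
        return List.of(
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.2), BigDecimal.valueOf(20.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.4), BigDecimal.valueOf(10.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.6), BigDecimal.valueOf(10.0)),
                new TaxEntry("Florida", "Orlando", BigDecimal.valueOf(0.3), BigDecimal.valueOf(13.0)));
    }

    public static void main(String[] args) {
        Map<StateCityGroup, RatePriceAggregation> mapAggregation = aggregate(sampleTaxes());
        System.out.println("New York: " + mapAggregation.get(new StateCityGroup("New York", "NYC")));
        // prints New York: RatePriceAggregation[count=3, ratePrice=14.00]
        System.out.println("Florida: " + mapAggregation.get(new StateCityGroup("Florida", "Orlando")));
        // prints Florida: RatePriceAggregation[count=1, ratePrice=3.90]
    }
}
```

Note that a key with a single element, such as Orlando, never reaches the merge function; its value comes straight from the second parameter.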

This is a straightforward implementation that determines the grouping and aggregation of multiple fields and non-primitive data types (BigDecimal in our case). However, it has the drawback that it does not have a finisher step that allows you to perform extra operations on the merged result. For example, you can’t do averages of any kind.

To showcase this issue, let’s consider a more complex problem. Suppose that we want to find the weighted average of the rate-price, and the sum of all the prices for each state and city pair. In particular, to find the weighted average, we need to calculate the sum of the product of the rate and price for all the entries that belong to each state-city pair, and then divide by the total number of entries n for each case: 1/n ∑(rate * price).
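As a sanity check of the formula, the arithmetic for the three New York/NYC sample entries can be done by hand with a plain loop: (0.2 * 20.0 + 0.4 * 10.0 + 0.6 * 10.0) / 3 = 14.00 / 3. The sketch below assumes rounding to two decimals with RoundingMode.HALF_DOWN and a non-empty input list; the RatePrice record and the class name are hypothetical helpers introduced only for this check.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

class WeightedAverageByHand {

    // Hypothetical helper type: one rate/price pair.
    record RatePrice(BigDecimal rate, BigDecimal price) {}

    // Computes 1/n * sum(rate * price) for a non-empty list.
    // Assumption: the result is rounded to 2 decimals using HALF_DOWN.
    static BigDecimal weightedAverage(List<RatePrice> entries) {
        BigDecimal sumOfProducts = BigDecimal.ZERO;
        for (RatePrice entry : entries) {
            sumOfProducts = sumOfProducts.add(entry.rate().multiply(entry.price()));
        }
        return sumOfProducts.divide(BigDecimal.valueOf(entries.size()), 2, RoundingMode.HALF_DOWN);
    }

    // The three New York/NYC entries from the sample data.
    static List<RatePrice> newYorkEntries() {
        return List.of(
                new RatePrice(BigDecimal.valueOf(0.2), BigDecimal.valueOf(20.0)),
                new RatePrice(BigDecimal.valueOf(0.4), BigDecimal.valueOf(10.0)),
                new RatePrice(BigDecimal.valueOf(0.6), BigDecimal.valueOf(10.0)));
    }

    public static void main(String[] args) {
        System.out.println(weightedAverage(newYorkEntries())); // prints 4.67
    }
}
```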

To tackle this problem, we start by defining a record that comprises the aggregation:

record TaxEntryAggregation(int count, BigDecimal weightedAveragePrice, BigDecimal totalPrice) {}

With this in hand, we can do the following implementation:

Map<StateCityGroup, TaxEntryAggregation> groupByAggregation = taxes.stream().collect(
    groupingBy(p -> new StateCityGroup(p.getState(), p.getCity()), 
               mapping(p -> new TaxEntryAggregation(1, p.getRate().multiply(p.getPrice()), p.getPrice()), 
                       collectingAndThen(reducing(new TaxEntryAggregation(0, BigDecimal.ZERO, BigDecimal.ZERO),
                                                  (u1,u2) -> new TaxEntryAggregation(u1.count() + u2.count(),
                                                      u1.weightedAveragePrice().add(u2.weightedAveragePrice()), 
                                                      u1.totalPrice().add(u2.totalPrice()))
                                                  ),
                                         u -> new TaxEntryAggregation(u.count(), 
                                                 u.weightedAveragePrice().divide(BigDecimal.valueOf(u.count()),
                                                                                 2, RoundingMode.HALF_DOWN), 
                                                 u.totalPrice())
                                         )
                      )
              ));

We can see that the code is somewhat more complicated, but it allows us to get the solution we are looking for. Let’s walk through it in more detail:

  • Collectors::groupingBy (line 2):
    1. For the classification function, we create a StateCityGroup record.
    2. For the downstream, we invoke Collectors::mapping (line 3):
      • For the first parameter, the mapper that we apply to the input elements transforms each grouped state-city tax record into a new TaxEntryAggregation entry that sets the initial count to 1, multiplies the rate by the price, and sets the price (line 3).
      • For the downstream, we invoke Collectors::collectingAndThen (line 4), which, as we’ll see, allows us to apply a finishing transformation to the downstream collector.
        • Invoke Collectors::reducing (line 4):
          1. Create a default TaxEntryAggregation that serves as the identity value of the reduction and covers the cases where there are no downstream elements (line 4).
          2. Provide a lambda expression that does the reduction and returns a new TaxEntryAggregation holding the aggregations of the fields (lines 5, 6, 7).
        • Perform the finishing transformation, calculating the average with the count computed in the previous reduction and returning the final TaxEntryAggregation (lines 9–12).

We see that this implementation not only allows us to do multiple field aggregations simultaneously but can also perform complex calculations in several stages.

This can be easily generalized to solve more complex problems. The path is straightforward: define a record that encapsulates all the fields that need to be aggregated, use Collectors::mapping to initialize the records, and then apply Collectors::collectingAndThen to do the reduction and final aggregation.
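As an illustration of that recipe applied to a different calculation, here is a sketch of a geometric mean of the prices per state-city pair, computed as exp(1/n ∑ ln(price)). It is not taken from the article's code: the LogSum record, the double-based TaxEntry record, the class name, and the sample data are assumptions made for this example, and prices are assumed to be strictly positive.

```java
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.reducing;

class GeometricMeanDemo {

    // Hypothetical simplified entity; prices are assumed to be strictly positive.
    record TaxEntry(String state, String city, double price) {}
    record StateCityGroup(String state, String city) {}

    // Step 1: a record that holds every field needed by the aggregation.
    record LogSum(int count, double sumOfLogs) {}

    static Map<StateCityGroup, Double> geometricMeanPriceByStateCity(List<TaxEntry> taxes) {
        return taxes.stream().collect(
                groupingBy(p -> new StateCityGroup(p.state(), p.city()),
                           // Step 2: mapping initializes one record per element.
                           mapping(p -> new LogSum(1, Math.log(p.price())),
                                   // Step 3: reduce the records, then apply the finisher.
                                   collectingAndThen(
                                           reducing(new LogSum(0, 0.0),
                                                    (u1, u2) -> new LogSum(u1.count() + u2.count(),
                                                                           u1.sumOfLogs() + u2.sumOfLogs())),
                                           u -> Math.exp(u.sumOfLogs() / u.count())))));
    }

    // Made-up sample data, for illustration only.
    static List<TaxEntry> sampleTaxes() {
        return List.of(
                new TaxEntry("New York", "NYC", 20.0),
                new TaxEntry("New York", "NYC", 10.0),
                new TaxEntry("New York", "NYC", 10.0),
                new TaxEntry("Florida", "Orlando", 13.0));
    }

    public static void main(String[] args) {
        double nyc = geometricMeanPriceByStateCity(sampleTaxes()).get(new StateCityGroup("New York", "NYC"));
        // The cube root of 20 * 10 * 10 = 2000 is about 12.5992.
        System.out.printf("%.4f%n", nyc);
    }
}
```

Summing logarithms instead of multiplying the prices directly keeps the intermediate value from overflowing when a group has many elements.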

As before we can get the aggregations for New York:

System.out.println("Finished aggregation: " + groupByAggregation.get(new StateCityGroup("New York", "NYC")));

We get the results:

Finished aggregation: TaxEntryAggregation[count=3, weightedAveragePrice=4.67, totalPrice=40.0]

It is also worth pointing out that because TaxEntryAggregation is a Java record, it’s immutable, so the calculation can be parallelized using the support provided by the stream collector’s library.
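A small experiment along these lines could look as follows: the same collector is run on a sequential and on a parallel stream and the two result maps are compared. This is a sketch, not a benchmark; the record-based TaxEntry, the class name, and the boolean parallel switch are assumptions introduced here, and with only four elements no speed-up should be expected.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.reducing;

class ParallelAggregationDemo {

    // Hypothetical record version of the article's TaxEntry class.
    record TaxEntry(String state, String city, BigDecimal rate, BigDecimal price) {}
    record StateCityGroup(String state, String city) {}
    record TaxEntryAggregation(int count, BigDecimal weightedAveragePrice, BigDecimal totalPrice) {}

    static Map<StateCityGroup, TaxEntryAggregation> aggregate(List<TaxEntry> taxes, boolean parallel) {
        // The only difference between the two runs is how the stream is created.
        Stream<TaxEntry> stream = parallel ? taxes.parallelStream() : taxes.stream();
        return stream.collect(
                groupingBy(p -> new StateCityGroup(p.state(), p.city()),
                           mapping(p -> new TaxEntryAggregation(1, p.rate().multiply(p.price()), p.price()),
                                   collectingAndThen(
                                           // The identity and the merged values are immutable records,
                                           // so partial results can be combined in any grouping.
                                           reducing(new TaxEntryAggregation(0, BigDecimal.ZERO, BigDecimal.ZERO),
                                                    (u1, u2) -> new TaxEntryAggregation(
                                                            u1.count() + u2.count(),
                                                            u1.weightedAveragePrice().add(u2.weightedAveragePrice()),
                                                            u1.totalPrice().add(u2.totalPrice()))),
                                           u -> new TaxEntryAggregation(
                                                   u.count(),
                                                   u.weightedAveragePrice().divide(
                                                           BigDecimal.valueOf(u.count()), 2, RoundingMode.HALF_DOWN),
                                                   u.totalPrice())))));
    }

    // Same sample data as in the article.
    static List<TaxEntry> sampleTaxes() {
        return List.of(
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.2), BigDecimal.valueOf(20.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.4), BigDecimal.valueOf(10.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.6), BigDecimal.valueOf(10.0)),
                new TaxEntry("Florida", "Orlando", BigDecimal.valueOf(0.3), BigDecimal.valueOf(13.0)));
    }

    public static void main(String[] args) {
        Map<StateCityGroup, TaxEntryAggregation> sequential = aggregate(sampleTaxes(), false);
        Map<StateCityGroup, TaxEntryAggregation> parallel = aggregate(sampleTaxes(), true);
        System.out.println(sequential.equals(parallel)); // prints true
    }
}
```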

Conclusion

We have shown a couple of strategies to do complex multi-field groupings with aggregations that include non-primitive data types and multi-field and cross-field calculations. Both work on a list of records using Java Streams and the Collectors API, which gives us the ability to process large amounts of data quickly and efficiently.
