Friday, April 03, 2009

Higher order functions in Java

I started learning Scala sometime towards the beginning of this year. At a certain point during this period, I found Java, my programming language at work (then and now), incredibly tedious to code in, mainly because of its verbosity. I remember thinking that the girlfriend metaphor described here summed up my situation quite aptly.

Since then, things have settled down somewhat (between the metaphorical wife and girlfriend), but I notice that my programming style has changed subtly to become more "Scala-like". Some of it is just smart programming, such as reorganizing method calls with dependencies on each other's side effects into stateless function calls which always return a copy of the original object. The other thing I have begun to notice is the huge number of for-loops that I end up writing in application code, mainly for doing some pretty mundane things such as filtering collections and converting a collection of one type to one of a different type.

Scala and Python (the other two languages that I know enough to code in) both have higher order functions which convert these for-loops into logical one-liners. Java does not have built-in language support for this, but the Apache commons-collections project and its generics-enabled cousin from Larvalabs both have methods in CollectionUtils that do.

However, quite a few of the CollectionUtils methods operate on collections in place and mutate them, so there is still the problem of unintended side effects if you are not careful, and potentially hard-to-read code if you are. To get around this, I decided to write my own versions that return a transformed copy of the input collection. Unlike CollectionUtils, my methods operate on Lists and mimic the behavior of their Scala namesakes.

I do find some of the methods in CollectionUtils useful, such as find, exists and forAllDo. There are a bunch of nice recipes related to the Commons Collections classes in Philip Senger's blog, and I found this recipe dealing with functors particularly interesting.
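To make the in-place mutation problem described above concrete, here is a small sketch contrasting a filter that mutates its argument with one that returns a copy. It assumes Java 8 or later and uses only java.util (List.removeIf and the Streams API), not CollectionUtils; the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;

// Hypothetical demo class: in-place filtering vs. copy-returning filtering.
public class InPlaceVsCopy {

  // Mutates the argument: every element failing the predicate is removed.
  // The list must be modifiable (a plain Arrays.asList view would throw).
  static <E> void filterInPlace(List<E> list, Predicate<E> keep) {
    list.removeIf(keep.negate());
  }

  // Leaves the argument untouched and returns a new List.
  static <E> List<E> filterCopy(List<E> list, Predicate<E> keep) {
    return list.stream().filter(keep).collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<Integer> original = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5));
    List<Integer> evens = filterCopy(original, i -> i % 2 == 0);
    System.out.println(evens + " " + original);  // [2, 4] [1, 2, 3, 4, 5]
    filterInPlace(original, i -> i % 2 == 0);
    System.out.println(original);                // [2, 4]
  }
}
```

With the copying version, any other code holding a reference to `original` keeps seeing the full list; with the in-place version it silently changes underneath them.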

The ListUtils class

The ListUtils class exposes a set of static methods that operate on Lists, Lists of Lists, and so on. The code is shown below; I describe each group of methods in more detail in the sections that follow.

// Source: src/main/java/com/mycompany/utils/ListUtils.java
package com.mycompany.utils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.collections15.Predicate;
import org.apache.commons.collections15.Transformer;

/**
 * Static Utility Methods that operate on Lists. Inspired by List methods
 * available natively in Scala and Python.
 */
public class ListUtils {

  /**
   * Return a List made up of elements from another List that pass through
   * the given predicate. For example, {1,2,3,4,5}.filter(x => (x % 2 == 0))
   * would yield {2,4}.
   * @param <E> the type of the element in the List.
   * @param input the List of E.
   * @param predicate the Predicate on E to decide if it should be returned.
   * @return a List of E.
   */
  public static <E> List<E> filter(List<E> input, Predicate<E> predicate) {
    if (input == null) {
      return null;
    }
    List<E> output = new ArrayList<E>();
    for (E e : input) {
      if (predicate.evaluate(e)) {
        output.add(e);
      }
    }
    return output;
  }
  
  /**
   * Apply a transformation to each element of the input List, resulting in 
   * a List of objects of the same or different type. For example, 
   * {1,2,3}.map(x => "s" + x) would yield {"s1","s2","s3"}.
   * @param <E> the input type.
   * @param <O> the output type.
   * @param input the List of E.
   * @param transformer transformer to transform E to O.
   * @return the List of O.
   */
  public static <E,O> List<O> map(List<E> input, Transformer<E,O> transformer) {
    if (input == null) {
      return null;
    }
    List<O> output = new ArrayList<O>();
    for (E e : input) {
      O o = transformer.transform(e);
      output.add(o);
    }
    return output;
  }

  /**
   * The effect of this operation is the same as recursively combining 
   * all but the last element with the last element. For example,
   * {1,2,3,4,5}.reduceLeft((x,y) => x + y) would be the same as 
   * ((((1 + 2) + 3) + 4) + 5). The transformer takes a pair of E in
   * List<E> and returns a transformed reduction E.
   * @param <E> the type of the List element.
   * @param input the List of E to reduce.
   * @param reducer the reducing function modelled as a Transformer.
   * @return the reduced E value.
   */
  public static <E> E reduceLeft(List<E> input, Transformer<List<E>,E> reducer) {
    if (input == null) {
      return null;
    }
    E reduced = null;
    int size = input.size();
    for (int i = 0; i < size; i++) {
      if (i == 0) {
        reduced = input.get(i);
        continue;
      }
      reduced = (E) reducer.transform(pack(reduced, input.get(i)));
    }
    return reduced;
  }

  /**
   * The effect of this operation is the same as recursively combining the
   * first element of the input list with the results of combining the rest
   * of the list. For example, {1,2,3,4,5}.reduceRight((x,y) => x + y)
   * would be the same as (1 + (2 + (3 + (4 + 5)))). The transformer takes 
   * a pair of E and returns the reduced value E.
   * @param <E> the type of the List element.
   * @param input the List of E to reduce.
   * @param reducer the reducer function modelled as a Transformer.
   * @return the reduced E value.
   */
  public static <E> E reduceRight(List<E> input, Transformer<List<E>,E> reducer) {
    E reduced = null;
    int size = input.size();
    for (int i = size - 1; i >= 0; i--) {
      if (i == size - 1) {
        reduced = input.get(i);
        continue;
      }
      reduced = (E) reducer.transform(pack(input.get(i), reduced));
    }
    return reduced;
  }

  /**
   * Same as reduceLeft, but the reduced value is primed with a specified
   * initial value.
   * @param <E> the type of the List element.
   * @param input the List of E.
   * @param initialValue the initial value to start reducing from.
   * @param folder the folding transformer.
   * @return the folded value of E.
   */
  public static <E> E foldLeft(List<E> input, E initialValue, 
      Transformer<List<E>,E> folder) {
    int size = input.size();
    E folded = initialValue;
    for (int i = 0; i < size; i++) {
      folded = (E) folder.transform(pack(folded, input.get(i)));
    }
    return folded;
  }
  
  /**
   * Same as reduceRight, but the reduced value is primed with a specified
   * initial value.
   * @param <E> the type of the List element.
   * @param input the List of E.
   * @param initialValue the initial value for the reduction.
   * @param folder the folding transformer.
   * @return the folded value of E.
   */
  public static <E> E foldRight(List<E> input, E initialValue, 
      Transformer<List<E>,E> folder) {
    int size = input.size();
    E folded = initialValue;
    for (int i = size - 1; i >= 0; i--) {
      folded = (E) folder.transform(pack(input.get(i), folded));
    }
    return folded;
  }
  
  /**
   * Partitions an input List into two Lists. The first list contains the
   * elements where the Predicate returns true, and the second list contains
   * the elements where the Predicate returns false. For example, the call:
   * {1,2,3,4,5}.partition(x => (x % 2 == 0)) will return the lists {2,4} 
   * and {1,3,5}.
   * @param <E> the type of the List element.
   * @param input the List of E to be partitioned.
   * @param predicate the predicate to partition the list.
   * @return a pair of Lists partitioned by the predicate.
   */
  public static <E> List<List<E>> partition(List<E> input, Predicate<E> predicate) {
    List<E> truePartition = new ArrayList<E>();
    List<E> falsePartition = new ArrayList<E>();
    for (E e : input) {
      if (predicate.evaluate(e)) {
        truePartition.add(e);
      } else {
        falsePartition.add(e);
      }
    }
    List<List<E>> partitioned = new ArrayList<List<E>>();
    partitioned.add(truePartition);
    partitioned.add(falsePartition);
    return partitioned;
  }
  
  /**
   * Partitions an input list into multiple partitions based on a partitioning
   * transformer. The transformer takes an element and returns a 0-based integer
   * representing the list into which this element will be placed. 
   * @param <E> the type of the List element.
   * @param input the List of E.
   * @param partitioner a partitioning transformer.
   * @return the partitioned List of List<E>.
   */
  public static <E> List<List<E>> partition(List<E> input, 
      Transformer<E,Integer> partitioner) {
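    Map<Integer,List<E>> parts = new HashMap<Integer,List<E>>();
    int maxPartId = -1;
    for (E e : input) {
      Integer partId = partitioner.transform(e);
      if (partId == null || partId < 0) {
        throw new IllegalArgumentException("Invalid partition id: " + partId);
      }
      List<E> part = parts.get(partId);
      if (part == null) {
        part = new ArrayList<E>();
        parts.put(partId, part);
      }
      part.add(e);
      maxPartId = Math.max(maxPartId, partId);
    }
    // HashMap iteration order is unspecified, so walk the ids 0..max in
    // order; ids that received no elements yield an empty List.
    List<List<E>> partitions = new ArrayList<List<E>>();
    for (int partId = 0; partId <= maxPartId; partId++) {
      List<E> part = parts.get(partId);
      partitions.add(part == null ? new ArrayList<E>() : part);
    }
    return partitions;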
  }

  /**
   * Flattens a List of Lists into a single flat list. Null or empty inner
   * Lists and null elements are ignored during flattening. For example, the
   * List of Lists {{1,2,3},{4,5,6}}.flatten yields {1,2,3,4,5,6}.
   * @param <E> the type of the List element.
   * @param inputs the list of list of E.
   * @return a flattened list of E.
   */
  public static <E> List<E> flatten(List<List<E>> inputs) {
    List<E> flattened = new ArrayList<E>();
    for (List<E> le : inputs) {
      if (le == null || le.size() == 0) {
        continue;
      }
      for (E e : le) {
        if (e == null) {
          continue;
        }
        flattened.add(e);
      }
    }
    return flattened;
  }

  /**
   * Joins individual elements of a pair of lists. Result is a list of
   * pairs of elements. If the sizes of the two input lists are different,
   * then the size of the zipped list is the smaller of the two sizes of
   * the input. For example, {{1,2,3},{4,5}}.zip will yield {{1,4},{2,5}}.
   * @param <E> the type of the list element.
   * @param left the first list of E.
   * @param right the second list of E.
   * @return the zipped list of pairs of E.
   */
  public static <E> List<List<E>> zip(List<E> left, List<E> right) {
    List<List<E>> zipped = new ArrayList<List<E>>();
    int leftSize = left.size();
    int rightSize = right.size();
    int minSize = Math.min(leftSize, rightSize);
    for (int i = 0; i < minSize; i++) {
      List<E> listE = new ArrayList<E>();
      listE.add(left.get(i));
      listE.add(right.get(i));
      zipped.add(listE);
    }
    return zipped;
  }
  
  /**
   * Inverse of the zip operation. Given a List of pairs of elements, this
   * method breaks it up into two lists of E. For example, {{1,4},{2,5}}.unzip
   * will yield {{1,2},{4,5}}.
   * @param <E> the type of the List element.
   * @param inputs the List of pairs of E.
   * @return two lists of E.
   */
  public static <E> List<List<E>> unzip(List<List<E>> inputs) {
    List<E> left = new ArrayList<E>();
    List<E> right = new ArrayList<E>();
    for (List<E> le : inputs) {
      left.add(le.get(0));
      right.add(le.get(1));
    }
    List<List<E>> unzipped = 
      new ArrayList<List<E>>();
    unzipped.add(left);
    unzipped.add(right);
    return unzipped;
  }
  
  /**
   * A combination of map and flatten. Each element of the input list is 
   * transformed into a List of E by the transformer, and flattened out into
   * the output list. For example, {1,2}.flatMap(x => List(x, x * 10))
   * will yield {1,10,2,20}.
   * @param <E> the type of the list element.
   * @param input the list of E.
   * @param transformer transforms each element to a List of elements.
   * @return the flattened mapped list.
   */
  public static <E> List<E> flatMap(List<E> input,
      Transformer<E,List<E>> transformer) {
    List<E> flattened = new ArrayList<E>();
    for (E e : input) {
      List<E> transformed = transformer.transform(e);
      if (transformed != null && transformed.size() > 0) {
        flattened.addAll(transformed); // nulls are allowed
      }
    }
    return flattened;
  }

  /**
   * Merges a List of Lists into a single list based on the sequencer. The
   * sequencer is a special transformer that takes an integer representing
   * the current list index (0-based) and returns the next list index to 
   * pick the next element from. It also needs to return the first list 
   * index when passed in an argument of -1.
   * @param <E> the type of the list element.
   * @param inputs the List of Lists of E.
   * @param sequencer a specialized transformer specifying merge sequence.
   * @return the merged list.
   */
  public static <E> List<E> merge(List<List<E>> inputs, 
      Transformer<Integer,Integer> sequencer) {
    int[] sizes = new int[inputs.size()];
    for (int i = 0; i < sizes.length; i++) {
      sizes[i] = inputs.get(i).size();
    }
    int[] current = new int[inputs.size()];
    List<E> merged = new ArrayList<E>();
    int lid = sequencer.transform(-1);
    for (;;) {
      if (lid < 0 || lid >= inputs.size()) {
        throw new IllegalArgumentException("List index out of bounds: " + lid);
      }
      if (current[lid] < sizes[lid]) {
        merged.add(inputs.get(lid).get(current[lid]));
      }
      current[lid]++;
      boolean alldone = true;
      for (int i = 0; i < sizes.length; i++) {
        if (current[i] < sizes[i]) {
          alldone = false;
          break;
        }
      }
      if (alldone) { 
        break;
      }
      lid = sequencer.transform(lid);
    }
    return merged;
  }
  
  /**
   * Convenience method to pack two elements E into a List<E>.
   * @param <E> the type of element.
   * @param first the first element in the list.
   * @param second the second element in the list.
   * @return the List containing the pair of E.
   */
  private static <E> List<E> pack(E first, E second) {
    List<E> packed = new ArrayList<E>();
    packed.add(first);
    packed.add(second);
    return packed;
  }
}

For each operation or group of operations, I wrote a JUnit test to show its usage in Java. I also ran the equivalent Scala code to make sure that my versions behave the same way. I describe these in the sections below.

filter

Filter extracts the elements for which the given predicate is true. Here is some Scala code to extract the even numbers from a List of Integers.

scala> val list = List(1,2,3,4,5,6,7,8,9,10)
scala> list.filter(x => (x % 2 == 0))
res0: List[Int] = List(2, 4, 6, 8, 10)

You can call this method in Java as shown below. Although the Predicate here is defined as an anonymous inner class, there is no reason you cannot assign it to a variable and pass that in, or even make a named class out of it (although the last approach may lead to a very large number of tiny classes in your application, and possibly make your code harder to maintain).

  @Test
  public void testFilter() throws Exception {
    List<Integer> list = 
      Arrays.asList(new Integer[] {1,2,3,4,5,6,7,8,9,10});
    List<Integer> evens = ListUtils.filter(list, 
      new Predicate<Integer>() {
        public boolean evaluate(Integer i) {
          return (i % 2 == 0);
        }
    });
    System.out.println(">>> evens = " + evens);
  }

which produces the following output:

>>> evens = [2, 4, 6, 8, 10]
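For comparison, on Java 8 and later the java.util.stream API covers the same ground without a third-party library. The sketch below (not part of ListUtils; the class and method names are hypothetical) keeps the predicate in a variable, as suggested above, so that it can be reused and composed with negate() and and().

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;

// Hypothetical demo class: filtering with a reusable java.util.function.Predicate.
public class FilterDemo {
  // A predicate held in a variable instead of an anonymous class.
  static final Predicate<Integer> IS_EVEN = i -> i % 2 == 0;

  // Returns a new List; the input is not modified.
  static List<Integer> evens(List<Integer> input) {
    return input.stream().filter(IS_EVEN).collect(Collectors.toList());
  }

  // Predicates compose: negate() and and() build new predicates.
  static List<Integer> oddsAboveFive(List<Integer> input) {
    return input.stream()
        .filter(IS_EVEN.negate().and(i -> i > 5))
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<Integer> list = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
    System.out.println(">>> evens = " + evens(list));           // [2, 4, 6, 8, 10]
    System.out.println(">>> odds > 5 = " + oddsAboveFive(list)); // [7, 9]
  }
}
```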

map

Map applies a transformation to each element of the input list, resulting in a list of the same or a different type. We show two usages of map in Scala below.

scala> val list = List(1,2,3,4,5)
scala> val squares = list.map(x => x * x) 
squares: List[Int] = List(1, 4, 9, 16, 25)
scala> val strings = list.map(x => "s" + x)     
strings: List[java.lang.String] = List(s1, s2, s3, s4, s5)

Using the ListUtils.map method in Java is a little more verbose, but otherwise quite similar, as shown below:

  @Test
  public void testMap() throws Exception {
    List<Integer> list = 
      Arrays.asList(new Integer[] {1,2,3,4,5,6,7,8,9,10});
    List<Integer> squares = ListUtils.map(list, 
      new Transformer<Integer,Integer>() {
        public Integer transform(Integer i) {
          return i * i;
        }
    });
    System.out.println(">>> squares = " + squares);
    List<String> strings = ListUtils.map(list, 
      new Transformer<Integer,String>() {
        public String transform(Integer i) {
          return "s" + i;
        }
    });
    System.out.println(">>> strings = " + strings);
  }

and produces the following output:

>>> squares = [1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
>>> strings = [s1, s2, s3, s4, s5, s6, s7, s8, s9, s10]
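As a point of comparison, here is a sketch of a generic map with the same shape as ListUtils.map, but built on java.util.function.Function and the Streams API (Java 8 or later). The class name is hypothetical; the output type is inferred from the lambda, just as the Transformer's second type parameter fixes it above.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical demo class: a map over Lists using java.util.function.Function.
public class MapDemo {
  // Applies f to every element and collects the results into a new List.
  static <E, O> List<O> map(List<E> input, Function<? super E, ? extends O> f) {
    return input.stream().map(f).collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<Integer> list = Arrays.asList(1, 2, 3, 4, 5);
    List<Integer> squares = map(list, i -> i * i);   // same type out
    List<String> strings = map(list, i -> "s" + i);  // different type out
    System.out.println(">>> squares = " + squares);  // [1, 4, 9, 16, 25]
    System.out.println(">>> strings = " + strings);  // [s1, s2, s3, s4, s5]
  }
}
```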

reduceLeft, reduceRight, foldLeft and foldRight

All these functions aim to apply a transformation to combine the elements of a list into a single element. The left and right suffixes to the method names indicate the direction of the reduction - reduceLeft reduces starting from the left and reduceRight starts from the right. The corresponding fold methods are the same as reduce, but they allow you to specify an initial value to the reduction. Here are some Scala examples:

scala> val ints = List(1,2,3,4,5,6,7,8,9,10)
scala> ints.reduceLeft((x,y) => (x + y))
res0: Int = 55
scala> ints.reduceRight((x,y) => (x + y))
res1: Int = 55
scala> ints.foldLeft(0)((x,y) => (x+y))
res2: Int = 55
scala> ints.foldRight(0)((x,y) => (x+y))
res3: Int = 55
scala> ints.reduceLeft((x,y) => if (x > y) x else y)
res4: Int = 10
scala> ints.reduceRight((x,y) => if (x > y) x else y)
res5: Int = 10
scala> ints.foldLeft(0)((x,y) => if (x > y) x else y) 
res6: Int = 10
scala> ints.foldRight(0)((x,y) => if (x > y) x else y)
res7: Int = 10
scala> val strs = ints.map(x => "s" + x)
scala> strs.reduceLeft((x,y) => (x + "," + y))
res8: java.lang.String = s1,s2,s3,s4,s5,s6,s7,s8,s9,s10
scala> strs.reduceRight((x,y) => (x + "," + y))
res9: java.lang.String = s1,s2,s3,s4,s5,s6,s7,s8,s9,s10
scala> strs.foldLeft("")((x,y) => (x + "," + y))
res10: java.lang.String = ,s1,s2,s3,s4,s5,s6,s7,s8,s9,s10
scala> strs.foldRight("")((x,y) => (x + "," + y))
res11: java.lang.String = s1,s2,s3,s4,s5,s6,s7,s8,s9,s10,

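Notice that the left and right variants return the same result in most of these examples. That holds whenever the combining operation is associative, as addition, max and comma-joining all are. The two string folds differ only because the initial value "" ends up at the left end for foldLeft and at the right end for foldRight, which produces the leading and trailing commas. The same operations using ListUtils are shown below: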

  @Test
  public void testReduceLeft() throws Exception {
    List<Integer> list = Arrays.asList(
      new Integer[] {1,2,3,4,5,6,7,8,9,10});
    List<String> strings = ListUtils.map(list, 
      new Transformer<Integer,String>() {
        public String transform(Integer i) {
          return "s" + i;
        }
    });
    Integer sum = ListUtils.reduceLeft(list, 
      new Transformer<List<Integer>,Integer>() {
        public Integer transform(List<Integer> args) {
          Integer first = args.get(0);
          Integer second = args.get(1);
          return first + second;
        }
    });
    System.out.println(">>> sum from reduceLeft = " + sum);
    Integer max = ListUtils.reduceLeft(list, 
      new Transformer<List<Integer>,Integer>() {
        public Integer transform(List<Integer> args) {
          Integer first = args.get(0);
          Integer second = args.get(1);
          return (first > second ? first : second);
        }
    });
    System.out.println(">>> max from reduceLeft = " + max);
    String strcat = ListUtils.reduceLeft(strings, 
      new Transformer<List<String>,String>() {
        public String transform(List<String> args) {
          return args.get(0) + "," + args.get(1);
        }
    });
    System.out.println(">>> strcat from reduceLeft = " + strcat);
  }
  
  @Test
  public void testReduceRight() throws Exception {
    List<Integer> list = Arrays.asList(
      new Integer[] {1,2,3,4,5,6,7,8,9,10});
    List<String> strings = ListUtils.map(list, 
      new Transformer<Integer,String>() {
        public String transform(Integer i) {
          return "s" + i;
        }
    });
    Integer sum = ListUtils.reduceRight(list, 
      new Transformer<List<Integer>,Integer>() {
        public Integer transform(List<Integer> args) {
          Integer first = args.get(0);
          Integer second = args.get(1);
          return first + second;
        }
    });
    System.out.println(">>> sum from reduceRight = " + sum);
    Integer max = ListUtils.reduceRight(list, 
      new Transformer<List<Integer>,Integer>() {
        public Integer transform(List<Integer> args) {
          Integer first = args.get(0);
          Integer second = args.get(1);
          return (first > second ? first : second);
        }
    });
    System.out.println(">>> max from reduceRight = " + max);
    String strcat = ListUtils.reduceRight(strings, 
      new Transformer<List<String>,String>() {
        public String transform(List<String> args) {
          return args.get(0) + "," + args.get(1);
        }
    });
    System.out.println(">>> strcat from reduceRight = " + strcat);
  }

  @Test
  public void testFoldLeft() throws Exception {
    List<Integer> list = Arrays.asList(
      new Integer[] {1,2,3,4,5,6,7,8,9,10});
    List<String> strings = ListUtils.map(list, 
      new Transformer<Integer,String>() {
        public String transform(Integer i) {
          return "s" + i;
        }
    });
    Integer sum = ListUtils.foldLeft(list, 0, 
      new Transformer<List<Integer>,Integer>() {
        public Integer transform(List<Integer> args) {
          Integer first = args.get(0);
          Integer second = args.get(1);
          return first + second;
        }
    });
    System.out.println(">>> sum from foldLeft = " + sum);
    Integer max = ListUtils.foldLeft(list, 0, 
      new Transformer<List<Integer>,Integer>() {
        public Integer transform(List<Integer> args) {
          Integer first = args.get(0);
          Integer second = args.get(1);
          return (first > second ? first : second);
        }
    });
    System.out.println(">>> max from foldLeft = " + max);
    String strcat = ListUtils.foldLeft(strings, "", 
      new Transformer<List<String>,String>() {
        public String transform(List<String> args) {
          return args.get(0) + "," + args.get(1);
        }
    });
    System.out.println(">>> strcat from foldLeft = " + strcat);
  }

  @Test
  public void testFoldRight() throws Exception {
    List<Integer> list = Arrays.asList(
      new Integer[] {1,2,3,4,5,6,7,8,9,10});
    List<String> strings = ListUtils.map(list, 
      new Transformer<Integer,String>() {
        public String transform(Integer i) {
          return "s" + i;
        }
    });
    Integer sum = ListUtils.foldRight(list, 0, 
      new Transformer<List<Integer>,Integer>() {
        public Integer transform(List<Integer> args) {
          Integer first = args.get(0);
          Integer second = args.get(1);
          return first + second;
        }
    });
    System.out.println(">>> sum from foldRight = " + sum);
    Integer max = ListUtils.foldRight(list, 0, 
      new Transformer<List<Integer>,Integer>() {
        public Integer transform(List<Integer> args) {
          Integer first = args.get(0);
          Integer second = args.get(1);
          return (first > second ? first : second);
        }
    });
    System.out.println(">>> max from foldRight = " + max);
    String strcat = ListUtils.foldRight(strings, "", 
      new Transformer<List<String>,String>() {
        public String transform(List<String> args) {
          return args.get(0) + "," + args.get(1);
        }
    });
    System.out.println(">>> strcat from foldRight = " + strcat);
  }

The above four tests produce the following outputs.

>>> sum from reduceLeft = 55
>>> max from reduceLeft = 10
>>> strcat from reduceLeft = s1,s2,s3,s4,s5,s6,s7,s8,s9,s10
>>> sum from reduceRight = 55
>>> max from reduceRight = 10
>>> strcat from reduceRight = s1,s2,s3,s4,s5,s6,s7,s8,s9,s10
>>> sum from foldLeft = 55
>>> max from foldLeft = 10
>>> strcat from foldLeft = ,s1,s2,s3,s4,s5,s6,s7,s8,s9,s10
>>> sum from foldRight = 55
>>> max from foldRight = 10
>>> strcat from foldRight = s1,s2,s3,s4,s5,s6,s7,s8,s9,s10,
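All of the operations above happen to be associative, so the direction of the reduction never shows up in the result. A non-associative operation such as subtraction makes the difference visible. The sketch below is standalone: it uses the same left and right loops as ListUtils, but assumes Java 8 or later and takes a java.util.function.BinaryOperator instead of a Transformer over a packed pair. The class name is hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.BinaryOperator;

// Hypothetical demo class: left vs. right reduction with BinaryOperator.
public class ReduceDirection {

  // ((((x1 op x2) op x3) op ...) op xn); assumes a non-empty list.
  static <E> E reduceLeft(List<E> input, BinaryOperator<E> op) {
    E acc = input.get(0);
    for (int i = 1; i < input.size(); i++) {
      acc = op.apply(acc, input.get(i));
    }
    return acc;
  }

  // (x1 op (x2 op (... op (x(n-1) op xn)))); assumes a non-empty list.
  static <E> E reduceRight(List<E> input, BinaryOperator<E> op) {
    E acc = input.get(input.size() - 1);
    for (int i = input.size() - 2; i >= 0; i--) {
      acc = op.apply(input.get(i), acc);
    }
    return acc;
  }

  // Like reduceLeft, but starts from init instead of the first element.
  static <E> E foldLeft(List<E> input, E init, BinaryOperator<E> op) {
    E acc = init;
    for (E e : input) {
      acc = op.apply(acc, e);
    }
    return acc;
  }

  // Like reduceRight, but starts from init instead of the last element.
  static <E> E foldRight(List<E> input, E init, BinaryOperator<E> op) {
    E acc = init;
    for (int i = input.size() - 1; i >= 0; i--) {
      acc = op.apply(input.get(i), acc);
    }
    return acc;
  }

  public static void main(String[] args) {
    List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5);
    BinaryOperator<Integer> minus = (x, y) -> x - y;
    System.out.println(reduceLeft(ints, minus));    // -13, i.e. ((((1-2)-3)-4)-5)
    System.out.println(reduceRight(ints, minus));   // 3, i.e. (1-(2-(3-(4-5))))
    List<String> strs = Arrays.asList("s1", "s2", "s3");
    BinaryOperator<String> join = (x, y) -> x + "," + y;
    System.out.println(foldLeft(strs, "", join));   // ,s1,s2,s3
    System.out.println(foldRight(strs, "", join));  // s1,s2,s3,
  }
}
```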

partition

The partition method in Scala uses a predicate to split the list into two parts and returns them as a pair. ListUtils also has an extra partition method that splits a List into multiple partitions based on a partitioning transformer. Here's how you would do the two-way split in Scala.

scala> val list = List(1,2,3,4,5)
scala> list.partition(x => (x % 2 == 0))
res0: (List[Int], List[Int]) = (List(2, 4),List(1, 3, 5))

And this JUnit test shows how to use the two partition methods with ListUtils.

  @Test
  public void testPartition() throws Exception {
    List<Integer> list = Arrays.asList(
      new Integer[] {1,2,3,4,5,6,7,8,9,10});
    List<List<Integer>> oddevens = ListUtils.partition(list, 
      new Predicate<Integer>() {
        public boolean evaluate(Integer i) {
          return (i % 2 == 1);
        }
    });
    List<List<Integer>> evenodds = ListUtils.partition(list, 
      new Predicate<Integer>() {
        public boolean evaluate(Integer i) {
          return (i % 2 == 0);
        }
    });
    System.out.println(">>> partition ints to odds and evens = " +
      oddevens.get(0) + " and " + oddevens.get(1));
    System.out.println(">>> partition ints to evens and odds = " +
      evenodds.get(0) + " and " + evenodds.get(1));
    List<List<Integer>> multiparts = ListUtils.partition(list, 
      new Transformer<Integer,Integer>() {
        public Integer transform(Integer i) {
          if (i <= 5) {
            if (i % 2 == 0) {
              return 0;
            } else {
              return 1;
            }
          } else {
            if (i % 2 == 0) {
              return 2;
            } else {
              return 3;
            }
          }
        }
    });
    System.out.println(">>> multipart partitioning ints = " +
      multiparts);
    List<String> strings = Arrays.asList(
      new String[] {"a1","a2","b1","b2","c1","c2"});
    List<List<String>> smultiparts = 
      ListUtils.partition(strings, new Transformer<String,Integer>() {
        public Integer transform(String s) {
          char firstchar = StringUtils.lowerCase(s).charAt(0);
          return firstchar - 'a';
        }
    });
    System.out.println(">>> multipart partitioning strings = " +
      smultiparts);
  }

and it produces the following output:

>>> partition ints to odds and evens = [1, 3, 5, 7, 9] and [2, 4, 6, 8, 10]
>>> partition ints to evens and odds = [2, 4, 6, 8, 10] and [1, 3, 5, 7, 9]
>>> multipart partitioning ints = [[2, 4], [1, 3, 5], [6, 8, 10], [7, 9]]
>>> multipart partitioning strings = [[a1, a2], [b1, b2], [c1, c2]]
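On Java 8 and later, java.util.stream.Collectors offers both flavors of partitioning. The sketch below (hypothetical class name, not part of ListUtils) uses partitioningBy for the two-way split, whose result map always contains both the true and false keys, and groupingBy with a TreeMap for the multi-way split, reusing the same partition ids as the transformer in testPartition. Unlike ListUtils.partition, groupingBy does not create empty lists for ids that receive no elements.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

// Hypothetical demo class: two-way and multi-way partitioning with Collectors.
public class PartitionDemo {
  // Two-way split keyed by the predicate's result.
  static Map<Boolean, List<Integer>> evensAndOdds(List<Integer> input) {
    return input.stream().collect(Collectors.partitioningBy(i -> i % 2 == 0));
  }

  // Multi-way split: 0 = small even, 1 = small odd, 2 = large even, 3 = large odd.
  // The TreeMap keeps partition ids in ascending order.
  static Map<Integer, List<Integer>> byRange(List<Integer> input) {
    return input.stream().collect(Collectors.groupingBy(
        i -> (i <= 5 ? 0 : 2) + (i % 2 == 0 ? 0 : 1),
        TreeMap::new,
        Collectors.toList()));
  }

  public static void main(String[] args) {
    List<Integer> list = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
    Map<Boolean, List<Integer>> parts = evensAndOdds(list);
    System.out.println(parts.get(true) + " and " + parts.get(false));
    System.out.println(byRange(list));  // {0=[2, 4], 1=[1, 3, 5], 2=[6, 8, 10], 3=[7, 9]}
  }
}
```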

forall, find and exists

The forAllDo() method in CollectionUtils applies a Closure to each element of a collection, which is what Scala's foreach method does, so I did not implement it in ListUtils. (Despite the similar name, Scala's forall is different: it returns true only if a predicate holds for every element of the list.)

The same goes for the find and exists methods: CollectionUtils.find() and CollectionUtils.exists() only read the collection and never modify it, so there was no need to reimplement them in ListUtils.
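For reference, on Java 8 and later these read-only queries map directly onto the Streams API: allMatch corresponds to Scala's forall, anyMatch to exists, filter(...).findFirst() to find, and Iterable.forEach to foreach. The helper names below are hypothetical, chosen only to mirror the Scala names.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;

// Hypothetical demo class: Scala-style queries on top of java.util.stream.
public class QueryDemo {
  // Scala forall: true only if the predicate holds for every element.
  static <E> boolean forall(List<E> input, Predicate<? super E> p) {
    return input.stream().allMatch(p);
  }

  // Scala exists: true if the predicate holds for at least one element.
  static <E> boolean exists(List<E> input, Predicate<? super E> p) {
    return input.stream().anyMatch(p);
  }

  // Scala find: the first matching element, or an empty Optional.
  // Note: findFirst() throws NullPointerException if that element is null.
  static <E> Optional<E> find(List<E> input, Predicate<? super E> p) {
    return input.stream().filter(p).findFirst();
  }

  public static void main(String[] args) {
    List<Integer> list = Arrays.asList(1, 2, 3, 4, 5);
    System.out.println(forall(list, i -> i > 0));        // true
    System.out.println(exists(list, i -> i > 4));        // true
    System.out.println(find(list, i -> i % 2 == 0));     // Optional[2]
    list.forEach(i -> System.out.print(i + " "));        // Scala foreach
    System.out.println();
  }
}
```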

zip, unzip and flatten

These are some of Scala's methods that I thought would be useful, so I went ahead and implemented them in ListUtils.

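The zip method "zips up" the elements of two input lists into a list of pairs. If the two input lists have different sizes, the zipped list is as long as the shorter one. The unzip method is the inverse of zip. The flatten method takes a list of lists and flattens it into a single list. Note one difference from Scala: in the Scala example, None is an ordinary value and survives flattening, whereas ListUtils.flatten skips null or empty inner lists and also drops null elements.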

scala> val list1 = List(1,2,3,4,5)
scala> val list2 = List(10,20,30)
scala> val zipped = list1.zip(list2)
zipped: List[(Int, Int)] = List((1,10), (2,20), (3,30))
scala> val unzipped = List.unzip(zipped)
unzipped: (List[Int], List[Int]) = (List(1, 2, 3),List(10, 20, 30))
scala> val lol = List(List(1,2,3),List(10,20,30))
scala> List.flatten(lol)
res0: List[Int] = List(1, 2, 3, 10, 20, 30)
scala> val lol2 = List(List(1,2),List(),List(None,10,9,8))
scala> List.flatten(lol2)
res1: List[Any] = List(1, 2, None, 10, 9, 8)

The code snippet below shows how to use these three methods in ListUtils.

  @Test
  public void testFlatten() throws Exception {
    List<Integer> list0 = Arrays.asList(new Integer[] {1,2,3});
    Transformer<Integer,Integer> mult10 = 
      new Transformer<Integer,Integer>() {
        public Integer transform(Integer i) {
          return i * 10;
        }
    };
    List<Integer> list10 = ListUtils.map(list0, mult10);
    List<List<Integer>> zipped = ListUtils.zip(list0, list10);
    System.out.println(">>> zipped = " + zipped);
    List<Integer> flattened = ListUtils.flatten(zipped);
    System.out.println(">>> flattened = " + flattened);
    List<List<Integer>> unzipped = ListUtils.unzip(zipped);
    System.out.println(">>> unzipped = " + unzipped);
    List<Integer> list10Le20 = ListUtils.filter(list10, 
      new Predicate<Integer>() {
        public boolean evaluate(Integer i) {
          return (i <= 20);
        }
    });
    List<List<Integer>> zippedLe20 = ListUtils.zip(
      list0, list10Le20);
    System.out.println(">>> zipped (le 20) = " + zippedLe20);
    List<Integer> flattenedLe20 = ListUtils.flatten(zippedLe20);
    System.out.println(">>> flattened (le 20) = " + flattenedLe20);
    List<List<Integer>> unzippedLe20 = 
      ListUtils.unzip(zippedLe20);
    System.out.println(">>> unzipped (le 20) = " + unzippedLe20);
  }

which produces the following output:

>>> zipped = [[1, 10], [2, 20], [3, 30]]
>>> flattened = [1, 10, 2, 20, 3, 30]
>>> unzipped = [[1, 2, 3], [10, 20, 30]]
>>> zipped (le 20) = [[1, 10], [2, 20]]
>>> flattened (le 20) = [1, 10, 2, 20]
>>> unzipped (le 20) = [[1, 2], [10, 20]]
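
As a rough illustration (again, not the ListUtils code), here is a minimal sketch of zip and unzip on plain java.util lists. Based on the List<List<Integer>> types in the test above, it assumes each zipped pair is represented as a two-element list; the class name ZipSketch is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper class sketching zip/unzip; not the author's ListUtils.
class ZipSketch {

  // Pairs up elements by position. Each pair is a two-element list.
  // The result is as long as the shorter input; extra elements are dropped.
  static <T> List<List<T>> zip(List<T> first, List<T> second) {
    int size = Math.min(first.size(), second.size());
    List<List<T>> zipped = new ArrayList<List<T>>(size);
    for (int i = 0; i < size; i++) {
      List<T> pair = new ArrayList<T>(2);
      pair.add(first.get(i));
      pair.add(second.get(i));
      zipped.add(pair);
    }
    return zipped;
  }

  // Inverse of zip: splits a list of pairs into
  // [all first elements, all second elements].
  static <T> List<List<T>> unzip(List<List<T>> zipped) {
    List<T> firsts = new ArrayList<T>(zipped.size());
    List<T> seconds = new ArrayList<T>(zipped.size());
    for (List<T> pair : zipped) {
      firsts.add(pair.get(0));
      seconds.add(pair.get(1));
    }
    List<List<T>> unzipped = new ArrayList<List<T>>(2);
    unzipped.add(firsts);
    unzipped.add(seconds);
    return unzipped;
  }

  public static void main(String[] args) {
    List<List<Integer>> zipped =
        zip(Arrays.asList(1, 2, 3, 4, 5), Arrays.asList(10, 20, 30));
    System.out.println(zipped);        // [[1, 10], [2, 20], [3, 30]]
    System.out.println(unzip(zipped)); // [[1, 2, 3], [10, 20, 30]]
  }
}
```

Representing pairs as lists rather than a dedicated Pair class keeps everything in one generic type, at the cost of only working when both inputs share an element type.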

flatMap

The flatMap method combines map and flatten: it maps each input element to zero or more output elements, then flattens the results into a single list. I have found this useful for navigating hierarchies and flattening out the nodes. Here is a simple Scala example.

scala> val list = List(1,2,3)
scala> list.flatMap(x => List(x, x*10, x*100))   
res0: List[Int] = List(1, 10, 100, 2, 20, 200, 3, 30, 300)

Here is the corresponding example in Java using ListUtils.

  @Test
  public void testFlatMap() throws Exception {
    List<Integer> list = Arrays.asList(new Integer[] {1,2,3});
    List<Integer> flatmapped = ListUtils.flatMap(list, 
      new Transformer<Integer,List<Integer>>() {
        public List<Integer> transform(Integer i) {
          List<Integer> transformed = new ArrayList<Integer>();
          transformed.add(i);
          transformed.add(i * 10);
          transformed.add(i * 100);
          return transformed;
        }
    });
    System.out.println(">>> flatmapped = " + flatmapped);
  }

This code produces the following output:

>>> flatmapped = [1, 10, 100, 2, 20, 200, 3, 30, 300]
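
A minimal flatMap sketch (not the ListUtils implementation) is just map followed by flatten in a single loop. It assumes Java 8 or later and uses java.util.function.Function in place of the commons-collections Transformer. The class name and the small parent-to-children map used to mimic a hierarchy are made up for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical helper class; Function stands in for Transformer here.
class FlatMapSketch {

  // Applies fn to each element and concatenates the resulting lists,
  // i.e. flatten(map(input, fn)). The input list is left untouched.
  static <I, O> List<O> flatMap(List<I> input, Function<I, List<O>> fn) {
    List<O> result = new ArrayList<O>();
    for (I element : input) {
      result.addAll(fn.apply(element));
    }
    return result;
  }

  public static void main(String[] args) {
    // Same transformation as the test above.
    List<Integer> flatmapped = flatMap(Arrays.asList(1, 2, 3),
        i -> Arrays.asList(i, i * 10, i * 100));
    System.out.println(flatmapped); // [1, 10, 100, 2, 20, 200, 3, 30, 300]

    // Navigating a made-up hierarchy: collect the children of each node.
    Map<String, List<String>> children = new HashMap<String, List<String>>();
    children.put("fruit", Arrays.asList("apple", "pear"));
    children.put("veg", Arrays.asList("leek"));
    List<String> leaves = flatMap(Arrays.asList("fruit", "veg", "nuts"),
        node -> children.getOrDefault(node, Collections.<String>emptyList()));
    System.out.println(leaves); // [apple, pear, leek]
  }
}
```

A node with no children ("nuts" above) maps to an empty list and so simply disappears from the result, which is what makes flatMap convenient for walking hierarchies.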

merge

The final method (not in Scala, as far as I know) is the merge method, which I wrote because I needed it. It merges multiple lists using a sequence rule, modelled here as a Transformer: given the index of the list currently being read (-1 before the first read), it returns the index of the list to read from next. Here is how one may use it.

  @Test
  public void testMerge() throws Exception {
    List<String> list1 = 
      Arrays.asList(new String[] {"a1", "a2", "a3"});
    List<String> list2 = Arrays.asList(new String[] {"b1"});
    List<String> list3 = Arrays.asList(new String[] {"c1", "c2"});
    List<String> list4 = Arrays.asList(new String[] {});
    List<List<String>> inputs = 
      new ArrayList<List<String>>();
    inputs.add(list1);
    inputs.add(list2);
    inputs.add(list3);
    inputs.add(list4);
    List<String> merged = ListUtils.merge(inputs, 
      new Transformer<Integer,Integer>() {
        public Integer transform(Integer i) {
          switch (i) { // sequence is {0,2,1,3}
            case -1:
              return 0;
            case 0:
              return 2;
            case 1:
              return 3; 
            case 2:
              return 1;
            case 3: // wrap around to the start of the sequence
            default: // should never happen
              return 0;
          }
        }
    });
    System.out.println(">>> merged = " + merged);
  }

The merge produces the following output:

>>> merged = [a1, c1, b1, a2, c2, a3]
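
The merge behaviour can be reconstructed from this example: the rule is called with the index of the current list (-1 at the start) and returns the index of the next list; each visit takes at most one element; empty or exhausted lists are passed over until every element has been used. The sketch below implements that reading. It is an assumption inferred from the output above, not the actual ListUtils code, and it uses java.util.function.Function (Java 8+) instead of Transformer. The class name MergeSketch is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;

// Hypothetical helper class; an inferred sketch of merge, not ListUtils.
class MergeSketch {

  // Interleaves elements from several lists. The rule maps the index of
  // the current list (-1 before the first read) to the index of the next
  // list to read from. Each visit takes at most one element; exhausted or
  // empty lists are skipped. Stops once every element has been taken, so
  // the rule must keep revisiting non-empty lists or the loop never ends.
  static <T> List<T> merge(List<List<T>> inputs, Function<Integer, Integer> rule) {
    int remaining = 0;
    for (List<T> input : inputs) {
      remaining += input.size();
    }
    int[] cursors = new int[inputs.size()]; // next unread position per list
    List<T> merged = new ArrayList<T>(remaining);
    int current = -1;
    while (remaining > 0) {
      current = rule.apply(current);
      List<T> input = inputs.get(current);
      if (cursors[current] < input.size()) {
        merged.add(input.get(cursors[current]++));
        remaining--;
      }
    }
    return merged;
  }

  public static void main(String[] args) {
    List<List<String>> inputs = Arrays.asList(
        Arrays.asList("a1", "a2", "a3"),
        Arrays.asList("b1"),
        Arrays.asList("c1", "c2"),
        Collections.<String>emptyList());
    // Visit the lists in the order 0, 2, 1, 3, then wrap around to 0.
    Function<Integer, Integer> rule = i -> {
      switch (i) {
        case -1: return 0;
        case 0:  return 2;
        case 2:  return 1;
        case 1:  return 3;
        default: return 0; // i == 3 wraps back to the first list
      }
    };
    System.out.println(merge(inputs, rule)); // [a1, c1, b1, a2, c2, a3]
  }
}
```

Because termination depends on the rule, a rule that never returns to a non-empty list would loop forever; a more defensive version might stop after a full cycle that yields nothing.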

I hope this stuff was useful. I've been using the functor objects from commons-collections (the larvalabs version) for a while now, but it's only recently, after learning about higher order functions in Scala (and later Python), that I have started thinking about how much cleaner my application code would be with logical one-liners instead of for-loops. I think that judicious use of this feature (i.e. resisting the temptation of the golden hammer :-)) can result in code that is easier and more fun to write, as well as more readable and hence easier to maintain.

2 comments:

Anonymous said...

I was looking for a partition that uses a predicate in guava (Google Collections) and found this post. I made a modification to use a Map<Boolean, List<E>> to hold the result. This makes doing result.get(true) a little more clear than result.get(0), even though yours is certainly more lisp-esque (list of lists). Anyway, here is a modified version of yours.

public static <E> Map<Boolean, List<E>> partition(List<E> input, Predicate<E> predicate) {
  Map<Boolean, List<E>> result =
    ImmutableMap.<Boolean, List<E>>builder()
      .put(true, new ArrayList<E>())
      .put(false, new ArrayList<E>())
      .build();
  for (E e : input) {
    result.get(predicate.apply(e)).add(e);
  }
  return result;
}

Sujit Pal said...

Thanks Anonymous. Guava seems to have some nice methods compared to the larvalabs collections; I have been meaning to take a look at it for a while now...