usaco-guide/content/3_Silver/Sorting_Custom.mdx
2020-07-17 18:29:45 -07:00


---
id: sorting-custom
title: "Sorting with Custom Comparators"
frequency: 3
author: Darren Yao, Siyong Huang, Michael Cao, Benjamin Qi
prerequisites:
- pairs-tuples
description: "Both Java and C++ have built-in functions for sorting. However, if we use custom objects, or if we want to sort elements in a different order, then we'll need to use a custom comparator."
---
import { Problem } from "../models";
export const metadata = {
problems: {
sample: [
new Problem("Silver", "Wormhole Sort", "992", "Normal", false, [], ""),
],
general: [
new Problem("CSES", "Restaurant Customers", "1619", "Easy", false, [], "Just sort endpoints of intervals and process in order."),
new Problem("Silver", "Lifeguards", "786", "Easy", false, [], "Similar to above."),
new Problem("Silver", "Rental Service", "787", "Easy", false, [], ""),
new Problem("Silver", "Mountains", "896", "Easy", false, [], ""),
new Problem("Silver", "Mooyo Mooyo", "860", "Easy", false, [], "Not a sorting problem, but you can use sorting to simulate gravity. - Write a custom comparator which puts zeroes at the front and use `stable_sort` to keep the relative order of other elements the same."),
new Problem("Silver", "Meetings", "967", "Very Hard", false, [], ""),
],
}
};
<Resources>
<Resource source="IUSACO" title="8 - Sorting & Comparators"></Resource>
</Resources>
<br />
## Example: Wormhole Sort
<Problems problems={metadata.problems.sample} />
There are multiple ways to solve this problem. We won't discuss the full solution here, but all of them start by sorting the edges in nondecreasing order of weight. For example, the sample contains the following edges:
```
1 2 9
1 3 7
2 3 10
2 4 3
```
After sorting, it should look like
```
2 4 3
1 3 7
1 2 9
2 3 10
```
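With C++, the easiest method is to store each edge as a nested `pair`, an `array<int,3>`, or a `vector<int>` with the weight first. All of these compare lexicographically by default, so sorting them orders the edges by weight.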
```cpp
#include <bits/stdc++.h>
using namespace std;
#define f first
#define s second

int main() {
    int M = 4;
    vector<pair<int,pair<int,int>>> v;
    for (int i = 0; i < M; ++i) {
        int a,b,w; cin >> a >> b >> w;
        v.push_back({w,{a,b}});
    }
    sort(begin(v),end(v));
    for (auto e: v) cout << e.s.f << " " << e.s.s << " " << e.f << "\n";
}
```
```cpp
#include <bits/stdc++.h>
using namespace std;

int main() {
    int M = 4;
    vector<array<int,3>> v; // or vector<vector<int>>
    for (int i = 0; i < M; ++i) {
        int a,b,w; cin >> a >> b >> w;
        v.push_back({w,a,b}); // weight first, so edges sort by weight
    }
    sort(begin(v),end(v));
    for (auto e: v) cout << e[1] << " " << e[2] << " " << e[0] << "\n";
}
```
In Python, we can use a list of lists.
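As a minimal sketch of the list-of-lists approach (the sample edges are hard-coded here instead of being read from input, and the name `edges` is only illustrative), each edge is stored as `[w, a, b]` so that the default lexicographic comparison looks at the weight first:

```py
# Each edge is stored as [w, a, b]; lists compare element by element,
# so putting the weight first sorts the edges by weight.
edges = []
for a, b, w in [(1, 2, 9), (1, 3, 7), (2, 3, 10), (2, 4, 3)]:
    edges.append([w, a, b])

edges.sort()

for w, a, b in edges:
    print(a, b, w)
```

This prints the edges in the same order as the sorted sample above.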
But in Java, we can't immediately sort an `ArrayList` of `ArrayList`s. What should we do?
- If we only stored the edge weights and sorted them, we would have a sorted list of edge weights, but it would be impossible to tell which weights corresponded to which edges.
- However, if we create a **class** (or struct) representing the edges and define a **custom comparator** to sort them by weight, we can sort the edges in ascending order while also keeping track of their endpoints.
## Classes
First, we need to define a **class** that represents what we want to sort. In our example we will define a class `Edge` that contains the two endpoints of the edge and the weight.
<LanguageSection>
<CPPSection>
```cpp
#include <bits/stdc++.h>
using namespace std;

struct Edge {
    int a,b,w;
};

int main() {
    int M = 4;
    vector<Edge> v;
    for (int i = 0; i < M; ++i) {
        int a,b,w; cin >> a >> b >> w;
        v.push_back({a,b,w});
    }
    for (Edge e: v) cout << e.a << " " << e.b << " " << e.w << "\n";
}
```
</CPPSection>
<JavaSection>
```java
import java.util.*;

public class Sol {
    static class Edge {
        int a,b,w;
        public Edge(int _a, int _b, int _w) { a = _a; b = _b; w = _w; }
    }
    public static void main(String[] args) {
        int M = 4;
        Scanner in = new Scanner(System.in);
        ArrayList<Edge> v = new ArrayList<Edge>();
        for (int i = 0; i < M; ++i) {
            int a = in.nextInt();
            int b = in.nextInt();
            int w = in.nextInt();
            v.add(new Edge(a,b,w));
        }
        for (Edge e: v){
            System.out.print(e.a);
            System.out.print(' ');
            System.out.print(e.b);
            System.out.print(' ');
            System.out.println(e.w);
        }
    }
}
```
</JavaSection>
<PySection>
```py
class Edge:
    def __init__(self, a, b, w):
        self.a = a
        self.b = b
        self.w = w

v = []
M = 4
for i in range(M):
    a,b,w = map(int,input().split())
    v.append(Edge(a,b,w))
for e in v:
    print(e.a,e.b,e.w)
```
</PySection>
</LanguageSection>
## Comparators
Sorting functions place objects with lower values before objects with higher values when sorting in ascending order, and the reverse when sorting in descending order. They decide where each object goes by comparing two objects at a time.
<LanguageSection>
<CPPSection>
<Resources>
<Resource source="CPH" title="3.2 - User-Defined Structs, Comparison Functions" starred></Resource>
</Resources>
What a comparator does is compare two objects as follows, based on our comparison criteria:
- If object $x$ is less than object $y$, return `true`
- If object $x$ is greater than or equal to object $y$, return `false`
Essentially, the comparator determines whether object $x$ belongs to the left of object $y$ in a sorted ordering. A comparator **must** return false for two identical objects (not doing so results in undefined behavior and potentially a runtime error).
In addition to returning the correct answer, comparators should also satisfy the following conditions:
- The function must be consistent with respect to reversing the order of the arguments: if $x \neq y$ and `compare(x, y)` is `true`, then `compare(y, x)` should be `false` and vice versa.
- The function must be transitive. If `compare(x, y)` is true and `compare(y, z)` is true, then `compare(x, z)` should also be true. If the first two compare functions both return `false`, the third must also return `false`.
### Method 1: Overloading the Less Than Operator
This is the easiest to implement. However, it only works for objects (not primitives) and it doesn't allow you to define multiple ways to compare the same type of class.
In the context of Wormhole Sort (note the use of
[const T&](https://stackoverflow.com/questions/11805322/why-should-i-use-const-t-instead-of-const-t-or-t)):
```cpp
#include <bits/stdc++.h>
using namespace std;

struct Edge {
    int a,b,w;
    // marked const so that it can also be called on const Edges
    bool operator<(const Edge& y) const { return w < y.w; }
};

int main() {
    int M = 4;
    vector<Edge> v;
    for (int i = 0; i < M; ++i) {
        int a,b,w; cin >> a >> b >> w;
        v.push_back({a,b,w});
    }
    sort(begin(v),end(v));
    for (Edge e: v) cout << e.a << " " << e.b << " " << e.w << "\n";
}
```
We can also overload the operator outside of the class:
```cpp
struct Edge {
    int a,b,w;
};
bool operator<(const Edge& x, const Edge& y) { return x.w < y.w; }
```
or within it using [friend](https://www.geeksforgeeks.org/friend-class-function-cpp/):
```cpp
struct Edge {
    int a,b,w;
    friend bool operator<(const Edge& x, const Edge& y) { return x.w < y.w; }
};
```
### Method 2: Comparison Function
This works for both objects and primitives, and you can declare many different comparators for the same object.
```cpp
#include <bits/stdc++.h>
using namespace std;

struct Edge {
    int a,b,w;
};

bool cmp(const Edge& x, const Edge& y) { return x.w < y.w; }

int main() {
    int M = 4;
    vector<Edge> v;
    for (int i = 0; i < M; ++i) {
        int a,b,w; cin >> a >> b >> w;
        v.push_back({a,b,w});
    }
    sort(begin(v),end(v),cmp);
    for (Edge e: v) cout << e.a << " " << e.b << " " << e.w << "\n";
}
```
We can also use [lambda expressions](https://www.geeksforgeeks.org/lambda-expression-in-c/) in C++11 or above.
```cpp
sort(begin(v),end(v),[](const Edge& x, const Edge& y) { return x.w < y.w; });
```
</CPPSection>
<JavaSection>
What a `Comparator` does is compare two objects as follows, based on our comparison criteria:
- If object $x$ is less than object $y$, return a negative integer.
- If object $x$ is greater than object $y$, return a positive integer.
- If object $x$ is equal to object $y$, return 0.
In addition to returning the correct number, comparators should also satisfy the following conditions:
- The function must be consistent with respect to reversing the order of the arguments: if `compare(x, y)` is positive, then `compare(y, x)` should be negative and vice versa.
- The function must be transitive. If `compare(x, y) > 0` and `compare(y, z) > 0`, then `compare(x, z) > 0`. Same applies if the compare functions return negative numbers.
- Equality must be consistent. If `compare(x, y) = 0`, then `compare(x, z)` and `compare(y, z)` must both be positive, both negative, or both zero. Note that they don't have to be equal, they just need to have the same sign.
Java has default functions for comparing `int`, `long`, `double` types. The `Integer.compare()`, `Long.compare()`, and `Double.compare()` functions take two arguments $x$ and $y$ and compare them as described above.
There are two ways of implementing this in Java: `Comparable` and `Comparator`. They essentially serve the same purpose, but `Comparable` is generally easier and shorter to code. `Comparable` is an interface implemented by the class of the objects being sorted, while a `Comparator` is implemented as its own separate class.
### Method 1: Comparable
We'll need to add `implements Comparable<Edge>` to the class declaration and implement the `compareTo` method. Essentially, `compareTo(x)` is the `compare` function that we described above with the object itself as the first argument, i.e. `compare(this, x)`.
When using Comparable, we can just call `Arrays.sort(arr)` or `Collections.sort(list)` on the array or list as usual.
```java
import java.util.*;

public class Sol {
    static class Edge implements Comparable<Edge> {
        int a,b,w;
        public Edge(int _a, int _b, int _w) { a = _a; b = _b; w = _w; }
        public int compareTo(Edge y) { return Integer.compare(w,y.w); }
    }
    public static void main(String[] args) {
        int M = 4;
        Scanner in = new Scanner(System.in);
        ArrayList<Edge> v = new ArrayList<Edge>();
        for (int i = 0; i < M; ++i) {
            int a = in.nextInt();
            int b = in.nextInt();
            int w = in.nextInt();
            v.add(new Edge(a,b,w));
        }
        Collections.sort(v);
        for (Edge e: v){
            System.out.print(e.a);
            System.out.print(' ');
            System.out.print(e.b);
            System.out.print(' ');
            System.out.println(e.w);
        }
    }
}
```
### Method 2: Comparator
If instead we choose to use `Comparator`, we'll need to declare a second class that implements `Comparator<Edge>`:
```java
import java.util.*;

public class Sol {
    static class Edge {
        int a,b,w;
        public Edge(int _a, int _b, int _w) { a = _a; b = _b; w = _w; }
    }
    static class Comp implements Comparator<Edge> {
        public int compare(Edge a, Edge b) {
            return Integer.compare(a.w, b.w);
        }
    }
    public static void main(String[] args) {
        int M = 4;
        Scanner in = new Scanner(System.in);
        ArrayList<Edge> v = new ArrayList<Edge>();
        for (int i = 0; i < M; ++i) {
            int a = in.nextInt();
            int b = in.nextInt();
            int w = in.nextInt();
            v.add(new Edge(a,b,w));
        }
        Collections.sort(v, new Comp());
        for (Edge e: v){
            System.out.print(e.a);
            System.out.print(' ');
            System.out.print(e.b);
            System.out.print(' ');
            System.out.println(e.w);
        }
    }
}
```
When using `Comparator`, the syntax for using the built-in sorting function requires a second argument: `Arrays.sort(arr, new Comp())`, or `Collections.sort(list, new Comp())`.
</JavaSection>
<PySection>
### Defining Less Than Operator
<!-- Tested -->
```py
class Edge:
    def __init__(self, a, b, w):
        self.a = a
        self.b = b
        self.w = w
    def __lt__(self, y): # lt means less than
        return self.w < y.w

v = []
M = 4
for i in range(M):
    a,b,w = map(int,input().split())
    v.append(Edge(a,b,w))
v.sort()
for e in v:
    print(e.a,e.b,e.w)
```
### Key Function
This method maps an object to another comparable datatype with which to be sorted. In this case we map edges to their weights.
```py
class Edge:
    def __init__(self, a, b, w):
        self.a = a
        self.b = b
        self.w = w

v = []
M = 4
for i in range(M):
    a,b,w = map(int,input().split())
    v.append(Edge(a,b,w))
v.sort(key=lambda x:x.w)
for e in v:
    print(e.a,e.b,e.w)
```
### Comparison Function
A comparison function in Python must satisfy the same properties as a comparator in Java. Note that old-style cmp functions are [no longer supported](https://stackoverflow.com/questions/12749398/using-a-comparator-function-to-sort), so the comparison function must be converted into a key function with [`cmp_to_key`](https://docs.python.org/2/library/functools.html).
```py
from functools import cmp_to_key

class Edge:
    def __init__(self, a, b, w):
        self.a = a
        self.b = b
        self.w = w

v = []
M = 4
for i in range(M):
    a,b,w = map(int,input().split())
    v.append(Edge(a,b,w))
v.sort(key=cmp_to_key(lambda a,b: a.w-b.w))
for e in v:
    print(e.a,e.b,e.w)
```
</PySection>
</LanguageSection>
## Variations
### Sorting in Decreasing Order of Weight
We can replace all occurrences of `x.w < y.w` with `x.w > y.w` in our C++ code. Similarly, we can replace all occurrences of `Integer.compare(x, y)` with `-Integer.compare(x, y)` in our Java code or negate the key in Python.
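For Python, a rough sketch of two options follows: negating the key, or keeping the key and passing `reverse=True`. To keep it short, the edges are hard-coded `(a, b, w)` tuples rather than the `Edge` class from above, and the names `by_neg_key` and `by_reverse` are only illustrative.

```py
edges = [(1, 2, 9), (1, 3, 7), (2, 3, 10), (2, 4, 3)]  # (a, b, w)

# Option 1: negate the key (this only works when the key is numeric).
by_neg_key = sorted(edges, key=lambda e: -e[2])

# Option 2: keep the key as is and ask for descending order.
by_reverse = sorted(edges, key=lambda e: e[2], reverse=True)

for a, b, w in by_neg_key:
    print(a, b, w)
```

Both lists end up ordered from the heaviest edge to the lightest.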
### Sorting by Two Criteria
Now, suppose we wanted to sort a list of `Edge`s in ascending order, primarily by weight and secondarily by first vertex (`a`). We can do this quite similarly to how we handled sorting by one criterion earlier. What the comparator function needs to do is to compare the weights if the weights are not equal, and otherwise compare first vertices.
<LanguageSection>
<CPPSection>
```cpp
struct Edge {
    int a,b,w;
    // marked const so that it can also be called on const Edges
    bool operator<(const Edge& y) const {
        if (w != y.w) return w < y.w;
        return a < y.a;
    }
};
```
</CPPSection>
<JavaSection>
```java
static class Edge implements Comparable<Edge> {
    int a,b,w;
    public Edge(int _a, int _b, int _w) { a = _a; b = _b; w = _w; }
    public int compareTo(Edge y) {
        if (w != y.w) return Integer.compare(w,y.w);
        return Integer.compare(a,y.a);
    }
}
```
</JavaSection>
<PySection>
```py
class Edge:
    def __init__(self, a, b, w):
        self.a = a
        self.b = b
        self.w = w
    def __lt__(self, y): # lt means less than
        return self.w < y.w if self.w != y.w else self.a < y.a
```
</PySection>
</LanguageSection>
Sorting by an arbitrary number of criteria is done similarly.
<LanguageSection>
<CPPSection>
</CPPSection>
<JavaSection>
With Java, we can implement a comparator for arrays of arbitrary length (although this might be more confusing than creating a separate class).
```java
import java.util.*;

public class Sol {
    static class Comp implements Comparator<int[]> {
        public int compare(int[] a, int[] b){
            for (int i = 0; i < a.length; ++i)
                if (a[i] != b[i]) return Integer.compare(a[i],b[i]);
            return 0;
        }
    }
    public static void main(String[] args) {
        int M = 4;
        Scanner in = new Scanner(System.in);
        ArrayList<int[]> v = new ArrayList<int[]>();
        for (int i = 0; i < M; ++i) {
            int a = in.nextInt();
            int b = in.nextInt();
            int w = in.nextInt();
            v.add(new int[]{w,a,b});
        }
        Collections.sort(v, new Comp());
        for (int[] e: v){
            System.out.print(e[1]);
            System.out.print(' ');
            System.out.print(e[2]);
            System.out.print(' ');
            System.out.println(e[0]);
        }
    }
}
```
</JavaSection>
<PySection>
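In Python, one way to sketch this is a key function that returns a tuple, since tuples compare element by element. The edges below are made up for illustration (several share the same weight so that the later criteria matter), and the helper name `sort_key` is an assumption, not something defined elsewhere in this module.

```py
class Edge:
    def __init__(self, a, b, w):
        self.a = a
        self.b = b
        self.w = w

def sort_key(e):
    # Compare by weight first, then by first vertex, then by second vertex.
    return (e.w, e.a, e.b)

v = [Edge(2, 3, 7), Edge(1, 3, 7), Edge(1, 2, 7), Edge(2, 4, 3)]
v.sort(key=sort_key)
for e in v:
    print(e.a, e.b, e.w)
```

Adding another criterion only requires appending another entry to the tuple.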
</PySection>
</LanguageSection>
## Problems
<Problems problems={metadata.problems.general} />