# 最小堆与索引堆

2020/10/08 13:25

/**
 * Common operations of a binary heap.
 *
 * <p>Whether the "top" element is the largest or the smallest is decided by the implementation
 * (max-heap or min-heap respectively).
 *
 * @param <T> type of the stored elements
 */
public interface Heap<T> {

    /**
     * @return the number of elements currently held by the heap
     */
    int getSize();

    /**
     * @return {@code true} when the heap holds no elements
     */
    boolean isEmpty();

    /**
     * Inserts an element into the heap.
     *
     * @param t the element to insert
     */
    void add(T t);

    /**
     * Removes the top element: the largest one for a max-heap, the smallest one for a min-heap.
     *
     * @return the element that was removed
     */
    T extract();

    /**
     * Looks at the top element (largest for a max-heap, smallest for a min-heap) without
     * removing it.
     *
     * @return the current top element
     */
    T findHeap();

    /**
     * Removes the top element (largest for a max-heap, smallest for a min-heap) and puts
     * {@code t} into the heap in its place.
     *
     * @param t the element that replaces the removed top element
     * @return the element that was removed
     */
    T replace(T t);
}
/**
 * Binary min-heap kept in a {@link List}, with the root at position 0.
 *
 * <p>For the node at position {@code i}, the children are at {@code 2i + 1} and {@code 2i + 2}
 * and the parent is at {@code (i - 1) / 2}. Elements are ordered by {@code compareTo}.
 *
 * @param <T> element type
 */
public class MinHeap<T extends Comparable<T>> implements Heap<T> {
    private final List<T> data;

    /**
     * Creates an empty heap whose backing list is pre-sized.
     *
     * @param capcity initial capacity passed to the backing {@link ArrayList}
     */
    public MinHeap(int capcity) {
        data = new ArrayList<>(capcity);
    }

    /** Creates an empty heap. */
    public MinHeap() {
        data = new ArrayList<>();
    }

    /**
     * Builds a min-heap from the elements of an array by sifting down every internal node,
     * starting from the parent of the last element.
     *
     * <p>Arrays with zero or one element are already valid heaps and need no sifting.
     *
     * @param arr the elements to copy into the heap
     */
    public MinHeap(T[] arr) {
        data = new ArrayList<>(Arrays.asList(arr));
        // parent() rejects position 0, so only look for a last internal node
        // when there are at least two elements.
        if (getSize() > 1) {
            for (int i = parent(getSize() - 1); i >= 0; i--) {
                shiftDown(i);
            }
        }
    }

    @Override
    public int getSize() {
        return data.size();
    }

    @Override
    public boolean isEmpty() {
        return data.isEmpty();
    }

    @Override
    public void add(T t) {
        data.add(t);
        shiftUp(getSize() - 1);
    }

    /**
     * Removes and returns the smallest element.
     *
     * @return the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public T extract() {
        T ret = findHeap();
        // Take the last element out of the list and, unless it was the root itself,
        // put it at the root and restore the heap order.
        T last = data.remove(getSize() - 1);
        if (!data.isEmpty()) {
            data.set(0, last);
            shiftDown(0);
        }
        return ret;
    }

    /**
     * Returns the smallest element without removing it.
     *
     * @return the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public T findHeap() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空，不能做此操作");
        }
        return data.get(0);
    }

    /**
     * Replaces the smallest element with {@code t} and restores the heap order.
     *
     * @param t the element to put into the heap
     * @return the element that was the smallest before the call
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public T replace(T t) {
        T ret = findHeap();
        data.set(0, t);
        shiftDown(0);
        return ret;
    }

    /**
     * Computes the position of a node's parent.
     *
     * @param index position of the node
     * @return position of its parent
     * @throws IllegalArgumentException if {@code index} is 0 (the root has no parent)
     */
    private int parent(int index) {
        if (index == 0) {
            throw new IllegalArgumentException("索引0没有父节点");
        }
        return (index - 1) / 2;
    }

    private int leftChild(int index) {
        return index * 2 + 1;
    }

    private int rightChild(int index) {
        return index * 2 + 2;
    }

    /** Moves the element at {@code index} up while it is smaller than its parent. */
    private void shiftUp(int index) {
        while (index > 0
                && data.get(parent(index)).compareTo(data.get(index)) > 0) {
            int up = parent(index);
            swap(up, index);
            index = up;
        }
    }

    /** Moves the element at {@code index} down while it is larger than its smaller child. */
    private void shiftDown(int index) {
        while (leftChild(index) < getSize()) {
            int left = leftChild(index);
            int right = rightChild(index);
            // Pick the smaller of the two children (the right one may not exist).
            int min = left;
            if (right < getSize()
                    && data.get(right).compareTo(data.get(left)) < 0) {
                min = right;
            }
            // Heap order already holds here, nothing more to do.
            if (data.get(index).compareTo(data.get(min)) <= 0) {
                break;
            }
            swap(index, min);
            index = min;
        }
    }

    /**
     * Exchanges the elements at two positions.
     *
     * @throws IllegalArgumentException if either position is outside the heap
     */
    private void swap(int x, int y) {
        if (x < 0 || x >= getSize() || y < 0 || y >= getSize()) {
            throw new IllegalArgumentException("索引非法");
        }
        T temp = data.get(x);
        data.set(x, data.get(y));
        data.set(y, temp);
    }
}

/**
 * Smoke test for {@code MinHeap}: pushes one million random non-negative ints, pops them all,
 * and checks that they come out in non-decreasing order.
 */
public class MinHeapMain {
    public static void main(String[] args) {
        final int total = 1000000;
        Random rng = new Random();
        Heap<Integer> heap = new MinHeap<>();

        int inserted = 0;
        while (inserted < total) {
            heap.add(rng.nextInt(Integer.MAX_VALUE));
            inserted++;
        }

        // Drain the heap completely before validating the order.
        int[] drained = new int[total];
        for (int pos = 0; pos < drained.length; pos++) {
            drained[pos] = heap.extract();
        }

        for (int pos = 1; pos < drained.length; pos++) {
            boolean outOfOrder = drained[pos] < drained[pos - 1];
            if (outOfOrder) {
                throw new IllegalArgumentException("出错");
            }
        }
        System.out.println("测试最小堆完成");
    }
}

/**
 * Sorts an array by loading it into a {@code MinHeap} and extracting every element in turn.
 *
 * <p>The sorted sequence is computed once in the constructor and exposed through
 * {@link #result()}, the method of the {@code MSTResult} interface (declared elsewhere in the
 * project).
 *
 * @param <T> element type
 */
public class HeapSort<T extends Comparable<T>> implements MSTResult<T> {
    private List<T> res = new ArrayList<>();

    /**
     * @param arr the elements to sort; they are copied into a heap
     */
    public HeapSort(T[] arr) {
        Heap<T> heap = new MinHeap<>(arr);
        int remaining = arr.length;
        while (remaining > 0) {
            res.add(heap.extract());
            remaining--;
        }
    }

    /**
     * @return the list filled by the constructor, in extraction order
     */
    @Override
    public List<T> result() {
        return res;
    }

    /** Demo: sorts ten random integers below 10000 and prints the result. */
    public static void main(String[] args) {
        final int size = 10;
        Random rng = new Random();
        Integer[] input = new Integer[size];
        for (int pos = 0; pos < input.length; pos++) {
            input[pos] = rng.nextInt(size * 1000);
        }
        MSTResult<Integer> sorter = new HeapSort<>(input);
        System.out.println(sorter.result());
    }
}

[1404, 1897, 2300, 2614, 3043, 4590, 6088, 6256, 9538, 9910]

/**
 * Operations of an index heap: a heap whose elements are additionally addressable through an
 * integer index, so that a stored element can be read or changed later.
 *
 * <p>Whether the "top" element is the largest or the smallest is decided by the implementation
 * (max-heap or min-heap respectively).
 *
 * @param <T> type of the stored elements
 */
public interface IndexHeap<T> {

    /**
     * @return the number of elements currently held by the heap
     */
    int getSize();

    /**
     * @return {@code true} when the heap holds no elements
     */
    boolean isEmpty();

    /**
     * Inserts an element under the given index.
     *
     * @param index the index to associate with the element
     * @param t the element to insert
     */
    void add(int index, T t);

    /**
     * Inserts an element; the index is chosen by the implementation.
     *
     * @param t the element to insert
     */
    void add(T t);

    /**
     * Removes the top element: the largest one for a max-heap, the smallest one for a min-heap.
     *
     * @return the element that was removed
     */
    T extract();

    /**
     * Removes the top element (largest for a max-heap, smallest for a min-heap) and reports its
     * index instead of its value.
     *
     * @return the index of the element that was removed
     */
    int extractIndex();

    /**
     * Reads the element stored under an index.
     *
     * @param index the index to look up
     * @return the element stored under {@code index}
     */
    T get(int index);

    /**
     * Replaces the element stored under an index with a new element.
     *
     * @param index the index whose element is replaced
     * @param t the new element
     */
    void change(int index, T t);

    /**
     * Looks at the top element (largest for a max-heap, smallest for a min-heap) without
     * removing it.
     *
     * @return the current top element
     */
    T findHeap();
}

/**
 * Index min-heap backed by two lists.
 *
 * <p>{@code data} holds the elements at their external index and is never reordered;
 * {@code indexes} is the heap proper: it holds external indices arranged so that the elements
 * they refer to form a binary min-heap with the root at position 0. Extracting removes an index
 * from {@code indexes} but leaves the element in {@code data}.
 *
 * <p>Looking up where an external index sits inside the heap requires a linear scan of
 * {@code indexes} in this version.
 *
 * @param <T> element type, ordered by {@code compareTo}
 */
@Getter
public class IndexMinHeap<T extends Comparable<T>> implements IndexHeap<T> {
    private final List<T> data;
    private final List<Integer> indexes;

    /**
     * Creates an empty heap whose backing lists are pre-sized.
     *
     * @param capcity initial capacity passed to the backing lists
     */
    public IndexMinHeap(int capcity) {
        data = new ArrayList<>(capcity);
        indexes = new ArrayList<>(capcity);
    }

    /** Creates an empty heap. */
    public IndexMinHeap() {
        data = new ArrayList<>();
        indexes = new ArrayList<>();
    }

    /**
     * Builds a heap from an array; element {@code arr[i]} gets external index {@code i}.
     *
     * <p>Arrays with zero or one element are already valid heaps and need no sifting.
     *
     * @param arr the elements to copy into the heap
     */
    public IndexMinHeap(T[] arr) {
        this(arr.length);
        data.addAll(Arrays.asList(arr));
        for (int i = 0; i < data.size(); i++) {
            indexes.add(i);
        }
        // parent() rejects position 0, so only sift when there are at least two elements.
        if (getSize() > 1) {
            for (int i = parent(getSize() - 1); i >= 0; i--) {
                shiftDown(i);
            }
        }
    }

    @Override
    public int getSize() {
        return indexes.size();
    }

    @Override
    public boolean isEmpty() {
        return indexes.isEmpty();
    }

    /**
     * Inserts an element under an explicit external index.
     *
     * <p>The slot is written in place, so the indices of elements already stored are not
     * disturbed. Slots below {@code index} that do not exist yet are created holding
     * {@code null}. A slot whose element was extracted earlier may be reused.
     *
     * @param index external index for the element; must be non-negative and not currently in
     *     the heap
     * @param t the element to insert
     * @throws IllegalArgumentException if {@code index} is negative or already in the heap
     */
    @Override
    public void add(int index, T t) {
        if (index < 0) {
            throw new IllegalArgumentException("索引非法");
        }
        if (indexes.contains(index)) {
            throw new IllegalArgumentException("索引已存在");
        }
        while (data.size() <= index) {
            data.add(null);
        }
        data.set(index, t);
        indexes.add(index);
        shiftUp(getSize() - 1);
    }

    /**
     * Inserts an element under the next unused external index ({@code data.size()}).
     *
     * @param t the element to insert
     */
    @Override
    public void add(T t) {
        int index = data.size();
        data.add(t);
        indexes.add(index);
        shiftUp(getSize() - 1);
    }

    /**
     * Removes the smallest element from the heap.
     *
     * @return the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public T extract() {
        return data.get(removeRoot());
    }

    /**
     * Removes the smallest element from the heap and reports its external index.
     *
     * @return the external index of the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public int extractIndex() {
        return removeRoot();
    }

    /**
     * Reads the element stored under an external index.
     *
     * @param index external index to look up
     * @return the element stored under {@code index}
     * @throws IllegalArgumentException if {@code index} is not currently in the heap
     */
    @Override
    public T get(int index) {
        if (!indexes.contains(index)) {
            throw new IllegalArgumentException("索引不存在");
        }
        return data.get(index);
    }

    /**
     * Replaces the element stored under an external index and restores the heap order.
     *
     * @param index external index whose element is replaced
     * @param t the new element
     * @throws IllegalArgumentException if {@code index} is not currently in the heap
     */
    @Override
    public void change(int index, T t) {
        int position = indexes.indexOf(index);
        if (position < 0) {
            throw new IllegalArgumentException("索引不存在");
        }
        data.set(index, t);
        // The new value may be smaller or larger than the old one, so try both directions;
        // at most one of them moves anything.
        shiftUp(position);
        shiftDown(position);
    }

    /**
     * Returns the smallest element without removing it.
     *
     * @return the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public T findHeap() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空，不能做此操作");
        }
        return data.get(indexes.get(0));
    }

    /**
     * Takes the root out of {@code indexes} and restores the heap order.
     *
     * @return the external index that was at the root
     * @throws IllegalArgumentException if the heap is empty
     */
    private int removeRoot() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空，不能做此操作");
        }
        int rootIndex = indexes.get(0);
        int last = getSize() - 1;
        swap(0, last);
        indexes.remove(last);
        shiftDown(0);
        return rootIndex;
    }

    /**
     * Computes the heap position of a node's parent.
     *
     * @param index heap position of the node
     * @return heap position of its parent
     * @throws IllegalArgumentException if {@code index} is 0 (the root has no parent)
     */
    private int parent(int index) {
        if (index == 0) {
            throw new IllegalArgumentException("索引0没有父节点");
        }
        return (index - 1) / 2;
    }

    private int leftChild(int index) {
        return index * 2 + 1;
    }

    private int rightChild(int index) {
        return index * 2 + 2;
    }

    /** Element referred to by the heap position {@code position}. */
    private T valueAt(int position) {
        return data.get(indexes.get(position));
    }

    /** Moves the entry at heap position {@code index} up while it is smaller than its parent. */
    private void shiftUp(int index) {
        while (index > 0 && valueAt(parent(index)).compareTo(valueAt(index)) > 0) {
            int up = parent(index);
            swap(up, index);
            index = up;
        }
    }

    /** Moves the entry at heap position {@code index} down while a child is smaller. */
    private void shiftDown(int index) {
        while (leftChild(index) < getSize()) {
            int left = leftChild(index);
            int right = rightChild(index);
            // Pick the smaller of the two children (the right one may not exist).
            int min = left;
            if (right < getSize() && valueAt(right).compareTo(valueAt(left)) < 0) {
                min = right;
            }
            // Heap order already holds here, nothing more to do.
            if (valueAt(index).compareTo(valueAt(min)) <= 0) {
                break;
            }
            swap(index, min);
            index = min;
        }
    }

    /**
     * Exchanges two heap positions; only {@code indexes} is touched, {@code data} stays put.
     *
     * @throws IllegalArgumentException if either position is outside the heap
     */
    private void swap(int x, int y) {
        if (x < 0 || x >= getSize() || y < 0 || y >= getSize()) {
            throw new IllegalArgumentException("索引非法");
        }
        int temp = indexes.get(x);
        indexes.set(x, indexes.get(y));
        indexes.set(y, temp);
    }
}

/**
 * Demo for the list-backed {@code IndexMinHeap}: fills a heap through {@code add}, drains it and
 * checks the order, then builds a second heap from an array and drains that one too.
 */
public class IndexMinHeapMain {
    public static void main(String[] args) {
        final int n = 10;
        Random random = new Random();

        // Part 1: heap filled one element at a time.
        IndexHeap<Integer> addedHeap = new IndexMinHeap<>(n);
        for (int k = 0; k < n; k++) {
            addedHeap.add(random.nextInt(n * 1000));
        }
        IndexMinHeap rawAdded = (IndexMinHeap) addedHeap;
        System.out.println(rawAdded.getData());

        Integer[] drained = drain(addedHeap, n);
        for (int k = 1; k < n; k++) {
            if (drained[k - 1] > drained[k]) {
                throw new IllegalArgumentException("出错");
            }
        }
        printAll(drained);

        Integer[] source = new Integer[n];
        for (int k = 0; k < n; k++) {
            source[k] = random.nextInt(n * 1000);
        }
        System.out.println();
        System.out.println("测试最小堆完成");

        // Part 2: heap built directly from an array.
        IndexHeap<Integer> arrayHeap = new IndexMinHeap<>(source);
        IndexMinHeap rawArray = (IndexMinHeap) arrayHeap;
        System.out.println(rawArray.getData());
        System.out.println(rawArray.getIndexes());
        printAll(drain(arrayHeap, n));
    }

    /** Extracts {@code count} elements from the heap into a new array, in extraction order. */
    private static Integer[] drain(IndexHeap<Integer> heap, int count) {
        Integer[] out = new Integer[count];
        for (int k = 0; k < count; k++) {
            out[k] = heap.extract();
        }
        return out;
    }

    /** Prints every element followed by a single space, without a trailing newline. */
    private static void printAll(Integer[] values) {
        for (Integer value : values) {
            System.out.print(value + " ");
        }
    }
}

[8126, 6448, 4014, 1574, 2933, 3193, 5137, 6246, 9873, 7044]
1574 2933 3193 4014 5137 6246 6448 7044 8126 9873

[2412, 2996, 3523, 7162, 9881, 8777, 8733, 865, 3719, 4991]
[7, 0, 2, 1, 9, 5, 6, 3, 8, 4]
865 2412 2996 3523 3719 4991 7162 8733 8777 9881 

    /**
     * Stores {@code t} under the external index {@code index}, then scans the whole heap for
     * positions that refer to that index and re-sifts each one in both directions.
     *
     * @param index external index whose element is replaced
     * @param t the new element
     */
    @Override
    public void change(int index, T t) {
        data.set(index, t);
        int position = 0;
        while (position < getSize()) {
            int stored = indexes.get(position);
            if (stored == index) {
                shiftUp(position);
                shiftDown(position);
            }
            position++;
        }
    }

    /**
     * Reads the element stored under an external index, after checking with a scan of
     * {@code indexes} that the index is still part of the heap.
     *
     * @param index external index to look up
     * @return the element stored under {@code index}
     * @throws IllegalArgumentException if {@code index} is not found in {@code indexes}
     */
    @Override
    public T get(int index) {
        boolean present = indexes.contains(index);
        if (present) {
            return data.get(index);
        }
        throw new IllegalArgumentException("索引不存在");
    }

/**
 * Index min-heap backed by three lists.
 *
 * <ul>
 *   <li>{@code data} holds the elements at their external index and is never reordered;
 *   <li>{@code indexes} is the heap proper: external indices arranged so that the elements they
 *       refer to form a binary min-heap with the root at position 0;
 *   <li>{@code reverse} maps an external index to its current position inside {@code indexes},
 *       or to {@code -1} when that index is not in the heap.
 * </ul>
 *
 * <p>{@code data} and {@code reverse} always have the same length. Extracting removes an index
 * from {@code indexes} and marks it {@code -1} in {@code reverse}, but leaves the element in
 * {@code data}.
 *
 * @param <T> element type, ordered by {@code compareTo}
 */
@Getter
public class IndexMinHeap<T extends Comparable<T>> implements IndexHeap<T> {
    // Elements, addressed by external index.
    private final List<T> data;
    // Heap of external indices.
    private final List<Integer> indexes;
    // External index -> position in indexes, or -1 when absent.
    private final List<Integer> reverse;

    /**
     * Creates an empty heap whose backing lists are pre-sized.
     *
     * @param capcity initial capacity passed to the backing lists
     */
    public IndexMinHeap(int capcity) {
        data = new ArrayList<>(capcity);
        indexes = new ArrayList<>(capcity);
        reverse = new ArrayList<>(capcity);
    }

    /** Creates an empty heap. */
    public IndexMinHeap() {
        data = new ArrayList<>();
        indexes = new ArrayList<>();
        reverse = new ArrayList<>();
    }

    /**
     * Builds a heap from an array; element {@code arr[i]} gets external index {@code i}.
     *
     * <p>Arrays with zero or one element are already valid heaps and need no sifting.
     *
     * @param arr the elements to copy into the heap
     */
    public IndexMinHeap(T[] arr) {
        this(arr.length);
        data.addAll(Arrays.asList(arr));
        for (int i = 0; i < data.size(); i++) {
            indexes.add(i);
            reverse.add(i);
        }
        // parent() rejects position 0, so only sift when there are at least two elements.
        if (getSize() > 1) {
            for (int i = parent(getSize() - 1); i >= 0; i--) {
                shiftDown(i);
            }
        }
    }

    @Override
    public int getSize() {
        return indexes.size();
    }

    @Override
    public boolean isEmpty() {
        return indexes.isEmpty();
    }

    /**
     * Inserts an element under an explicit external index.
     *
     * <p>The slot is written in place, so the indices of elements already stored are not
     * disturbed. Slots below {@code index} that do not exist yet are created holding
     * {@code null} and marked absent. A slot whose element was extracted earlier may be reused.
     *
     * @param index external index for the element; must be non-negative and not currently in
     *     the heap
     * @param t the element to insert
     * @throws IllegalArgumentException if {@code index} is negative or already in the heap
     */
    @Override
    public void add(int index, T t) {
        if (index < 0) {
            throw new IllegalArgumentException("索引非法");
        }
        if (contains(index)) {
            throw new IllegalArgumentException("索引已存在");
        }
        while (data.size() <= index) {
            data.add(null);
            reverse.add(-1);
        }
        data.set(index, t);
        indexes.add(index);
        // The new entry starts at the last heap position.
        reverse.set(index, getSize() - 1);
        shiftUp(getSize() - 1);
    }

    /**
     * Inserts an element under the next unused external index ({@code data.size()}).
     *
     * @param t the element to insert
     */
    @Override
    public void add(T t) {
        int index = data.size();
        data.add(t);
        indexes.add(index);
        // Record the heap position, not the external index: the two differ once
        // anything has been extracted.
        reverse.add(getSize() - 1);
        shiftUp(getSize() - 1);
    }

    /**
     * Removes the smallest element from the heap.
     *
     * @return the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public T extract() {
        return data.get(removeRoot());
    }

    /**
     * Removes the smallest element from the heap and reports its external index.
     *
     * @return the external index of the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public int extractIndex() {
        return removeRoot();
    }

    /**
     * Reads the element stored under an external index.
     *
     * @param index external index to look up
     * @return the element stored under {@code index}
     * @throws IllegalArgumentException if {@code index} is not currently in the heap
     */
    @Override
    public T get(int index) {
        if (!contains(index)) {
            throw new IllegalArgumentException("索引不存在");
        }
        return data.get(index);
    }

    /**
     * Replaces the element stored under an external index and restores the heap order.
     *
     * @param index external index whose element is replaced
     * @param t the new element
     * @throws IllegalArgumentException if {@code index} is not currently in the heap
     */
    @Override
    public void change(int index, T t) {
        if (!contains(index)) {
            throw new IllegalArgumentException("索引不存在");
        }
        data.set(index, t);
        int revIndex = reverse.get(index);
        // The new value may be smaller or larger than the old one, so try both directions;
        // at most one of them moves anything.
        shiftUp(revIndex);
        shiftDown(revIndex);
    }

    /**
     * Returns the smallest element without removing it.
     *
     * @return the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public T findHeap() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空，不能做此操作");
        }
        return data.get(indexes.get(0));
    }

    /**
     * Tells whether an external index is currently part of the heap. The range is checked
     * against the number of slots, not the number of elements still in the heap.
     */
    private boolean contains(int index) {
        return index >= 0 && index < reverse.size() && reverse.get(index) != -1;
    }

    /**
     * Takes the root out of {@code indexes}, marks it absent and restores the heap order.
     *
     * @return the external index that was at the root
     * @throws IllegalArgumentException if the heap is empty
     */
    private int removeRoot() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空，不能做此操作");
        }
        int rootIndex = indexes.get(0);
        int last = getSize() - 1;
        swap(0, last);
        indexes.remove(last);
        reverse.set(rootIndex, -1);
        shiftDown(0);
        return rootIndex;
    }

    /**
     * Computes the heap position of a node's parent.
     *
     * @param index heap position of the node
     * @return heap position of its parent
     * @throws IllegalArgumentException if {@code index} is 0 (the root has no parent)
     */
    private int parent(int index) {
        if (index == 0) {
            throw new IllegalArgumentException("索引0没有父节点");
        }
        return (index - 1) / 2;
    }

    private int leftChild(int index) {
        return index * 2 + 1;
    }

    private int rightChild(int index) {
        return index * 2 + 2;
    }

    /** Element referred to by the heap position {@code position}. */
    private T valueAt(int position) {
        return data.get(indexes.get(position));
    }

    /** Moves the entry at heap position {@code index} up while it is smaller than its parent. */
    private void shiftUp(int index) {
        while (index > 0 && valueAt(parent(index)).compareTo(valueAt(index)) > 0) {
            int up = parent(index);
            swap(up, index);
            index = up;
        }
    }

    /** Moves the entry at heap position {@code index} down while a child is smaller. */
    private void shiftDown(int index) {
        while (leftChild(index) < getSize()) {
            int left = leftChild(index);
            int right = rightChild(index);
            // Pick the smaller of the two children (the right one may not exist).
            int min = left;
            if (right < getSize() && valueAt(right).compareTo(valueAt(left)) < 0) {
                min = right;
            }
            // Heap order already holds here, nothing more to do.
            if (valueAt(index).compareTo(valueAt(min)) <= 0) {
                break;
            }
            swap(index, min);
            index = min;
        }
    }

    /**
     * Exchanges two heap positions and keeps {@code reverse} in step; {@code data} stays put.
     *
     * @throws IllegalArgumentException if either position is outside the heap
     */
    private void swap(int x, int y) {
        if (x < 0 || x >= getSize() || y < 0 || y >= getSize()) {
            throw new IllegalArgumentException("索引非法");
        }
        int atX = indexes.get(x);
        int atY = indexes.get(y);
        indexes.set(x, atY);
        indexes.set(y, atX);
        reverse.set(atY, x);
        reverse.set(atX, y);
    }
}

/**
 * Demo for the reverse-index {@code IndexMinHeap}: fills a heap through {@code add}, drains it
 * and checks the order, then builds a second heap from an array and exercises {@code extract},
 * {@code get} and {@code change} before draining it.
 */
public class IndexMinHeapMain {
    public static void main(String[] args) {
        int n = 10;
        Random random = new Random();

        // Part 1: heap filled one element at a time.
        IndexHeap<Integer> addedHeap = new IndexMinHeap<>(n);
        for (int k = 0; k < n; k++) {
            addedHeap.add(random.nextInt(n * 1000));
        }
        IndexMinHeap rawAdded = (IndexMinHeap) addedHeap;
        System.out.println(rawAdded.getData());

        Integer[] drained = drain(addedHeap, n);
        for (int k = 1; k < n; k++) {
            if (drained[k - 1] > drained[k]) {
                throw new IllegalArgumentException("出错");
            }
        }
        printAll(drained);

        Integer[] source = new Integer[n];
        for (int k = 0; k < n; k++) {
            source[k] = random.nextInt(n * 1000);
        }
        System.out.println();
        System.out.println("测试最小堆完成");

        // Part 2: heap built directly from an array.
        IndexHeap<Integer> arrayHeap = new IndexMinHeap<>(source);
        IndexMinHeap rawArray = (IndexMinHeap) arrayHeap;
        System.out.println(rawArray.getData());
        System.out.println(rawArray.getIndexes());

        // Drop the minimum, so one element fewer remains to be drained below.
        arrayHeap.extract();
        n--;
        System.out.println(arrayHeap.get(8));
        arrayHeap.change(3, 9527);
        printAll(drain(arrayHeap, n));
    }

    /** Extracts {@code count} elements from the heap into a new array, in extraction order. */
    private static Integer[] drain(IndexHeap<Integer> heap, int count) {
        Integer[] out = new Integer[count];
        for (int k = 0; k < count; k++) {
            out[k] = heap.extract();
        }
        return out;
    }

    /** Prints every element followed by a single space, without a trailing newline. */
    private static void printAll(Integer[] values) {
        for (Integer value : values) {
            System.out.print(value + " ");
        }
    }
}

[8524, 2544, 8147, 772, 4659, 3962, 6716, 9758, 9358, 3574]
772 2544 3574 3962 4659 6716 8147 8524 9358 9758

[6826, 3071, 8097, 5773, 6960, 9057, 6823, 8796, 7924, 3244]
[1, 9, 6, 3, 0, 5, 2, 7, 8, 4]
7924
3244 6823 6826 6960 7924 8097 8796 9057 9527 

/**
 * Fixed-capacity index min-heap backed by plain arrays.
 *
 * <p>Callers use 0-based external indices in {@code [0, capacity)}. Internally every array is
 * 1-based (slot 0 is unused), so external index {@code i} lives in slot {@code i + 1}:
 *
 * <ul>
 *   <li>{@code data[slot]} is the element stored in that slot;
 *   <li>{@code indexes[1..count]} is the heap proper: slots arranged so that the elements they
 *       refer to form a binary min-heap with the root at position 1;
 *   <li>{@code reverse[slot]} is the slot's current position inside {@code indexes}, or
 *       {@code 0} when the slot is not in the heap.
 * </ul>
 *
 * <p>Extracting removes a slot from the heap but leaves its element in {@code data}.
 *
 * @param <T> element type, ordered by {@code compareTo}
 */
@Getter
public class ArrayIndexMinHeap<T extends Comparable<T>> implements IndexHeap<T> {
    private T[] data;
    private int[] indexes;
    private int count;
    private int capacity;
    private int[] reverse;

    /**
     * Creates an empty heap that can hold external indices {@code 0 .. capacity - 1}.
     *
     * @param capacity maximum number of elements
     */
    @SuppressWarnings("unchecked")
    public ArrayIndexMinHeap(int capacity) {
        this.capacity = capacity;
        // The runtime array type is Comparable[]; it is only ever accessed as T[] internally.
        data = (T[]) new Comparable[capacity + 1];
        indexes = new int[capacity + 1];
        // A freshly allocated int[] is all zeros, i.e. every slot starts out absent.
        reverse = new int[capacity + 1];
        count = 0;
    }

    /**
     * Builds a heap from an array; element {@code arr[i]} gets external index {@code i} and the
     * capacity equals {@code arr.length}.
     *
     * <p>Arrays with zero or one element are already valid heaps and need no sifting.
     *
     * @param arr the elements to copy into the heap
     */
    public ArrayIndexMinHeap(T[] arr) {
        this(arr.length);
        for (int i = 0; i < arr.length; i++) {
            data[i + 1] = arr[i];
            indexes[i + 1] = i + 1;
            reverse[i + 1] = i + 1;
            count++;
        }
        // count / 2 is the last position that has a child; it is 0 for fewer than two
        // elements, so the loop is skipped and parent(1) is never asked for.
        for (int i = count / 2; i >= 1; i--) {
            shiftDown(i);
        }
    }

    /**
     * Computes the heap position of a node's parent.
     *
     * @param index heap position of the node
     * @return heap position of its parent
     * @throws IllegalArgumentException if {@code index} is 1 (the root has no parent)
     */
    private int parent(int index) {
        if (index == 1) {
            throw new IllegalArgumentException("索引1没有父节点");
        }
        return index / 2;
    }

    private int leftChild(int index) {
        return index * 2;
    }

    private int rightChild(int index) {
        return index * 2 + 1;
    }

    /**
     * Exchanges two heap positions and keeps {@code reverse} in step; {@code data} stays put.
     *
     * @throws IllegalArgumentException if either position is outside {@code [1, count]}
     */
    private void swap(int x, int y) {
        if (x < 1 || x > count || y < 1 || y > count) {
            throw new IllegalArgumentException("索引非法");
        }
        int temp = indexes[x];
        indexes[x] = indexes[y];
        indexes[y] = temp;
        reverse[indexes[x]] = x;
        reverse[indexes[y]] = y;
    }

    /** Moves the entry at heap position {@code index} up while it is smaller than its parent. */
    private void shiftUp(int index) {
        while (index > 1
                && data[indexes[parent(index)]].compareTo(data[indexes[index]]) > 0) {
            int up = parent(index);
            swap(up, index);
            index = up;
        }
    }

    /** Moves the entry at heap position {@code index} down while a child is smaller. */
    private void shiftDown(int index) {
        while (leftChild(index) <= count) {
            int left = leftChild(index);
            int right = rightChild(index);
            // Pick the smaller of the two children (the right one may not exist).
            int min = left;
            if (right <= count
                    && data[indexes[right]].compareTo(data[indexes[left]]) < 0) {
                min = right;
            }
            // Heap order already holds here, nothing more to do.
            if (data[indexes[index]].compareTo(data[indexes[min]]) <= 0) {
                break;
            }
            swap(index, min);
            index = min;
        }
    }

    /**
     * Tells whether a 0-based external index is currently part of the heap. Indices outside
     * {@code [0, capacity)} are reported as absent.
     */
    private boolean contains(int index) {
        return index >= 0 && index < capacity && reverse[index + 1] != 0;
    }

    /**
     * Takes the root out of the heap, marks its slot absent and restores the heap order.
     *
     * @return the internal (1-based) slot that was at the root
     * @throws IllegalArgumentException if the heap is empty
     */
    private int removeRoot() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空，不能做此操作");
        }
        int rootSlot = indexes[1];
        swap(1, count);
        reverse[rootSlot] = 0;
        count--;
        shiftDown(1);
        return rootSlot;
    }

    @Override
    public int getSize() {
        return count;
    }

    @Override
    public boolean isEmpty() {
        return count == 0;
    }

    /**
     * Inserts an element under a 0-based external index.
     *
     * @param index external index in {@code [0, capacity)} that is not currently in the heap
     * @param t the element to insert
     * @throws IllegalArgumentException if the heap is full, {@code index} is out of range, or
     *     {@code index} is already in the heap
     */
    @Override
    public void add(int index, T t) {
        if (count + 1 > capacity) {
            throw new IllegalArgumentException("数组越界");
        }
        if (index + 1 < 1 || index + 1 > capacity) {
            throw new IllegalArgumentException("数组越界");
        }
        // Adding a second heap entry for the same slot would corrupt indexes/reverse.
        if (contains(index)) {
            throw new IllegalArgumentException("索引已存在");
        }
        index++;
        data[index] = t;
        indexes[count + 1] = index;
        reverse[index] = count + 1;
        count++;
        shiftUp(count);
    }

    /**
     * Not supported: this heap requires an explicit index for every element.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void add(T t) {
        throw new UnsupportedOperationException("不提供该方法");
    }

    /**
     * Removes the smallest element from the heap.
     *
     * @return the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public T extract() {
        return data[removeRoot()];
    }

    /**
     * Removes the smallest element from the heap and reports its 0-based external index.
     *
     * @return the external index of the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public int extractIndex() {
        return removeRoot() - 1;
    }

    /**
     * Reads the element stored under a 0-based external index.
     *
     * @param index external index to look up
     * @return the element stored under {@code index}
     * @throws IllegalArgumentException if {@code index} is not currently in the heap
     */
    @Override
    public T get(int index) {
        if (!contains(index)) {
            throw new IllegalArgumentException("索引不存在");
        }
        return data[index + 1];
    }

    /**
     * Replaces the element stored under a 0-based external index and restores the heap order.
     *
     * @param index external index whose element is replaced
     * @param t the new element
     * @throws IllegalArgumentException if {@code index} is not currently in the heap
     */
    @Override
    public void change(int index, T t) {
        if (!contains(index)) {
            throw new IllegalArgumentException("索引不存在");
        }
        index++;
        data[index] = t;
        int revIndex = reverse[index];
        // The new value may be smaller or larger than the old one, so try both directions;
        // at most one of them moves anything.
        shiftUp(revIndex);
        shiftDown(revIndex);
    }

    /**
     * Returns the smallest element without removing it.
     *
     * @return the smallest element
     * @throws IllegalArgumentException if the heap is empty
     */
    @Override
    public T findHeap() {
        if (isEmpty()) {
            throw new IllegalArgumentException("堆为空，不能做此操作");
        }
        return data[indexes[1]];
    }
}

/**
 * Demo for {@code ArrayIndexMinHeap}: fills a heap through {@code add(index, value)}, drains it
 * and checks the order, then builds a second heap from an array, dumps its internal arrays
 * before and after one {@code extract}, and exercises {@code get} and {@code change} before
 * draining it.
 */
public class IndexMinHeapMain {
    public static void main(String[] args) {
        int n = 10;
        Random random = new Random();

        // Part 1: heap filled one element at a time under indices 0..n-1.
        IndexHeap<Integer> addedHeap = new ArrayIndexMinHeap<>(n);
        for (int k = 0; k < n; k++) {
            addedHeap.add(k, random.nextInt(n * 1000));
        }
        // Raw type on purpose: the backing array is read as-is, without a generic cast.
        ArrayIndexMinHeap rawAdded = (ArrayIndexMinHeap) addedHeap;
        printNonNull(rawAdded.getData());
        System.out.println("");

        Integer[] drained = drain(addedHeap, n);
        for (int k = 1; k < n; k++) {
            if (drained[k - 1] > drained[k]) {
                throw new IllegalArgumentException("出错");
            }
        }
        printAll(drained);

        Integer[] source = new Integer[n];
        for (int k = 0; k < n; k++) {
            source[k] = random.nextInt(n * 1000);
        }
        System.out.println();
        System.out.println("测试最小堆完成");

        // Part 2: heap built directly from an array.
        IndexHeap<Integer> arrayHeap = new ArrayIndexMinHeap<>(source);
        ArrayIndexMinHeap rawArray = (ArrayIndexMinHeap) arrayHeap;
        printNonNull(rawArray.getData());
        System.out.println();
        printFromOne(rawArray.getIndexes(), arrayHeap);
        System.out.println();

        // Drop the minimum, so one element fewer remains to be drained below.
        arrayHeap.extract();
        n--;
        printNonNull(rawArray.getData());
        System.out.println();
        printFromOne(rawArray.getIndexes(), arrayHeap);
        System.out.println();
        printFromOne(rawArray.getReverse(), arrayHeap);
        System.out.println();

        System.out.println(arrayHeap.get(8));
        arrayHeap.change(3, 9527);
        printAll(drain(arrayHeap, n));
    }

    /** Extracts {@code count} elements from the heap into a new array, in extraction order. */
    private static Integer[] drain(IndexHeap<Integer> heap, int count) {
        Integer[] out = new Integer[count];
        for (int k = 0; k < count; k++) {
            out[k] = heap.extract();
        }
        return out;
    }

    /** Prints every element followed by a single space, without a trailing newline. */
    private static void printAll(Integer[] values) {
        for (Integer value : values) {
            System.out.print(value + " ");
        }
    }

    /** Prints the non-null entries of an array, each followed by a single space. */
    private static void printNonNull(Object[] items) {
        for (Object item : items) {
            if (item != null) {
                System.out.print(item + " ");
            }
        }
    }

    /** Prints {@code values[1]} through {@code values[heap.getSize()]}, space separated. */
    private static void printFromOne(int[] values, IndexHeap<Integer> heap) {
        for (int k = 1; k <= heap.getSize(); k++) {
            System.out.print(values[k] + " ");
        }
    }
}

9003 7873 1421 7191 8573 799 471 2359 5132 6041
471 799 1421 2359 5132 6041 7191 7873 8573 9003

3438 4971 9353 8743 2528 8294 4976 8498 7307 7161
5 1 7 9 2 6 3 8 4 10
3438 4971 9353 8743 2528 8294 4976 8498 7307 7161
1 2 7 9 10 6 3 8 4
1 2 7 9 0 6 3 8 4
7307
3438 4971 4976 7161 7307 8294 8498 9353 9527 

0
0 收藏

0 评论
0 收藏
0