kanetaiの二次記憶装置

プログラミングに関するやってみた、調べた系のものをQitaに移して、それ以外をはてブでやる運用にしようと思います。http://qiita.com/kanetai

4 Values whose Sum is 0(PKU No.2785)

4 Values whose Sum is 0(PKU No.2785)

素数nの整数のリストA,B,C,Dが与えられたとき、各リストから1つずつ取り出したときの和が0となるような組み合わせの個数を求めよ。ただし、1つのリストに同じ値が重複している場合、それらは異なるものとして扱い、組み合わせの数を計算する。
制約
 1\leq n \leq 4000

与えられるリストの要素(整数)の値  \leq 2^{28}

アルゴリズム

全探索は、 n^4 = 4000^4 = 256,000,000,000,000で無謀。
http://kanetai.hatenablog.com/entry/20110504/1304513560と同様の半分全列挙戦略をとる。

まず、 CD=\{ c+d|c\in C,d\in D\} を求めて ( |CD|=n^2 )ソートしておく。
CDの中で、 \{ -(a+b)|a\in A,b\in B \} と一致するc+dの数を求める。
ソートされているので2分探索できる。

計算量は、CD生成に O(n^2)、ソートに O(n^2 \log n^2 ) = O(2n^2 \log n) = O(n^2 \log n)
 -(a+b)の組み合わせ n^2について、CDを2分探索 (O(\log n^2) = O(\log n))するのでこちらも O(n^2 \log n)
従って、全体で O(n^2\log n)

コード

C++なら、CD中の-(a+b)に一致する値の数を求めるのに、upper_bound()-lower_bound()が使える。
Javaには無いのかな?

import java.util.*;
import java.lang.*;
public class pku2785 {
	static Scanner scanner;
	int lower_bound(int[] array, int k){
		int lb = -1, ub = array.length;
		//解の存在範囲が1より大きい間、反復
		while(ub - lb > 1){
			int mid = (lb + ub) / 2;
			if( array[mid] >= k )	//midが条件を満たせば、解の存在範囲は(lb,mid]
				ub = mid;
			else				//midが条件を満たさなければ、解の存在範囲は(mid,ub]
				lb = mid;
		}
		return ub; //lb + 1 = ub
	}
	
	int upper_bound(int[] array, int k){
		int lb = -1, ub = array.length;
		//解の存在範囲が1より大きい間、反復
		while(ub - lb > 1){
			int mid =(lb + ub) / 2;
			if( array[mid] <= k )	//midが条件を満たせば、解の存在範囲は[mid,ub)
				lb = mid;
			else				//midが条件を満たさなければ、解の存在範囲は[lb,mid)
				ub = mid;
		}
		return ub;
	}
	
	public static void main(String[] args) {
		scanner = new Scanner(System.in);
		new pku2785();
	}
	pku2785(){
		int n = scanner.nextInt();
		int [] A = new int[n];
		int [] B = new int[n];
		int [] C = new int[n];
		int [] D = new int[n];
		for(int i=0; i<n; ++i){
			A[i] = scanner.nextInt();
			B[i] = scanner.nextInt();
			C[i] = scanner.nextInt();
			D[i] = scanner.nextInt();
		}
		System.out.println( solve(n,A,B,C,D) ); 
	}
	long solve(int n, int[] A, int[] B, int[] C, int[] D){
		int[] CD = new int[n*n];
		long res = 0;
		for(int i=0; i<n; ++i)
			for(int j=0; j<n; ++j)
				CD[ i*n + j ] = C[i] + D[j];
		Arrays.sort( CD );
		for(int i=0; i<n; ++i){
			for(int j=0; j<n; ++j){
				int nab = -( A[i] + B[j] );
				res += this.upper_bound(CD, nab) - this.lower_bound(CD, nab);
			}
		}
		return res;
	}

}