CodeIQ:「スクエア・カルテット」問題

私自身が表題の問題を解いた時のプログラムについて解説します。
問題の詳細は「スクエア・カルテット」問題CodeIQ)を参照してください。

問題の概要

問題を引用します。

2つの自然数の組 (a, b) が与えられたとき、自然数 x, y に関する次の方程式を考えます。
x2+a2=y2+b2 … (※)

(中略)
自然数の組 (a, b) に対し、方程式(※)の全ての解の x + y の和を F(a, b) と定義します。
(中略)
標準入力から、半角空白区切りで 2 つの自然数 a, b(1 ≦ a < b ≦ 105)が与えられます。
標準出力に F(a, b) の値を出力するプログラムを書いてください。

私のプログラム

C++で解答しています。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/**
 *  x^2 + a^2 = y^2 + b^2 ...(1)
 *  (1)を満たす全ての(x,y)の合計を求める。
 *
 *  (1)は次の式に等しい
 *  x^2-y^2 = b^2-a^2
 *  (x+y)(x-y) = b^2-a^2 ...(1')
 *
 *  (1')を満たす(x+y)は次のA,Bを満たす値に等しい。
 *  (A) (1')の右辺のそ因数の任意に組み合わせによる積
 *  (B) (A)で残った素因数を組み合わせた値以上の(A)
 *
 *  (x+y)はすでに(1)を満たす(x,y)の組み合わせを合計したものであるから、上記A,Bを満たす
 *  素因数の積の集合を合計すれば良い
 */
 
#include <stdio.h>
#include <stdlib.h>
#include <vector>
#include <set>
#include <algorithm>
 
using namespace std;
 
// 素因数分解する
vector< unsigned long long > insubunkai(unsigned long long n){
    unsigned long long m = n;
    vector< unsigned long long > ret;
 
    for(unsigned long long i=2; m>1 ;){
        if(m%i == 0){
            ret.push_back(i);
            m /= i;
        }
        else{
            i++;
        }
    }
 
    return ret;
}
 
// x,yが最初の式を満たすかをチェックする
// n: b^2-a^2
// XpY: x+y
int checkXY(unsigned long long n, unsigned long long XpY){
    unsigned long long XmY = n / XpY;
    unsigned long long x = (XpY + XmY)/2;
    unsigned long long y = (XpY - XmY)/2;
    int ret = 0;
    if((x*x-y*y == n) && x && y){
        ret = 1;
    }
//  printf("[%s]: x=%llu, y=%llu, X^2*y^2=%llu\n", (ret?"OK":"NG"), x, y, x*x-y*y);
 
    return ret;
}
 
// (x+y)を重複なく取得する
// v: b^2-a^2の素因数の集合
// n: b^2-a^2
set< unsigned long long > getXplusY(vector< unsigned long long >& v, unsigned long long n){
    set< unsigned long long > ret;
    do{
        for(int r=1; r<=v.size(); r++){
            unsigned long long XpY = 1;
            for(unsigned long long i=0; i<r; i++){
                XpY *= v[i];
            }
            unsigned long long XmY = n / XpY;
            if((XpY >= XmY) && checkXY(n, XpY)){
                ret.insert(XpY);
            }
        }
    }while(next_permutation(v.begin(), v.end()));
    return ret;
}
 
unsigned long long  getSum(set< unsigned long long  > s){
    unsigned long long  sum = 0;
    for(set< unsigned long long  >::const_iterator i=s.begin(); i!=s.end(); i++){
        sum += *i;
    }
    return sum;
}
 
int main(int argc, char* argv[]){
    char str[1024] = {0};
    while( fgets(str, sizeof(str), stdin) != NULL ){
        unsigned long long a;
        unsigned long long b;
 
        sscanf(str, "%llu %llu", &a, &b);
        unsigned long long  n = b*b-a*a;
 
        vector< unsigned long long  > v = insubunkai(n);
        set< unsigned long long  > s = getXplusY(v, n);
        printf("%llu\n", getSum(s));
    }
 
    return 0;
}

解説

この問題は結構難しいです。ポイントは2点あります。
1つは条件を満たすX、Yを列挙する方法、もう一つは計算量です。
ちなみに私のプログラムは正しい結果を出力しますが、力技で解決している部分があってあまり美しいとは言えません。

基本的な考え方

これはプログラムの先頭に記述したコメントのとおりです。
まず、x2+a2=y2+b2は次のように変形できます。
  x2-y2=b2-a2

さらに次のように変形できます。
  (x+y)(x-y)=b2-a2 …(a)

(a)から(x+y)はb2-a2の素因数を2グループに分けて、その積が大きい方であることは明らかです。問題のF(a,b)は条件を満たす(x+y)の合計なので前述の条件を満たす値を列挙し、それを合計すれば良いということになります。

b2-a2

これは入力値から容易に計算できます。90行目でいきなり計算してしまいます。

素因数分解

insubunkai()でb2-a2を素因数分解しています。
ロジックは単純である数で割り切れる間はその数で割り続け、割り切れなくなったら1大きい値で割るという処理を繰り返して素因数分解しています。割り切れた場合、割った値を素因数として集めておきます。
素数以外で割るという処理が入りますが、無駄な処理が入るだけで正しい結果になります。例えば2、3ときて4で割る場合、すでに2で割れる分は全部処理しているので何もせず5に進むからです。
無駄を省きたければ1000以下の素数表を使用すれば性能を改善できます。ただ、今回の入力値ではb2-a2 < 1000000は明らかなので問題ない(最大で高々1000回しかループしないので頑張って高速化しなくて良い)と判断しました。

(x+y)を列挙する

次に素因数の集合から(x+y)を列挙します。getXplusY()がその処理をする関数です。
ここで使っているnext_permutation()はSTLの<algorithm>にあるコンテナ要素の順列を全て列挙してくれる関数です。本当は組み合わせを列挙してくれる関数があれば全く無駄がなかったのですが、STLにはなかったので無駄を承知でnext_permutation()を使用しました。next_permutation()は線形時間で処理できると説明されていたので今回の問題なら無駄を含めても十分高速に処理できます。

next_permutation()のdo-whileループの中ではforループで最初の1個だけ、次は先頭から2個、3個、…と素因数を取り出して積を求めて(x+y)の候補を生成しています。
同時にb2-a2をその候補の値で割って(x-y)を求めて条件、つまり、(x+y) > (x-y)になる候補だけをピックアップしています。
ピックアップした結果をsetに入れているのは同じ値が重複した場合、それを無視するためです。

ここで1つ妙な処理checkXY()があります。次はこれについて説明します。

checkXY()

これは何をしているのかというと、(x+y)のxとyが本当に正しい値かをチェックする処理です。私自身なぜかがわかっていないのですが、これまでの処理で求めたx、yが元の式x2-y2=b2-a2を満たさない場合があるため、元の式と結果を照らし合わせて正しい値だけを選別しているというわけです。

getSum()

これは特に問題などないでしょう。getXplusY()で求めた(x+y)の集合から合計を求めて答えを出しています。
setは重複を許さないので正しい答えが求まります。

雑感

以上ですが、このプログラムはかなり無駄な処理や強引なところが多いです。
ちなみに、この問題を見た時点でC言語、JavaScript、PHPは選択肢にありませんでした(この問題をやった時点で私はPythonを覚えていません)。C言語は可変長の配列と集合を扱う便利な機能が無いためで、JSとPHPは主に性能です。JavaよりC++を優先したのは性能を少しでも稼ぎたかったからだったと思いますが、STLの<algorithm>にnext_permutation()があったのはラッキーでした。Javaだと組み合わせを列挙する処理を自作しなければならず大分苦労したのではないかと思います。