徒然

思いついたら書きます

AtCoder practice contest B - Interactive SortingをPythonで解く

はじめに

AtCoder practice contest B - Interactive SortingPythonで解いた記事が見つからないので、簡単に解説してみます。

インタラクティブ問題について

クエリを受け取るときはいつも通りinput()を使えばよいです。

ただしクエリを送信する際はprint()をそのまま使えば良いというわけではなく、以下のようにflushする必要があることに注意が必要です。

print(ans, flush = True)

インタラクティブな部分を関数化する

今回は比較部分がインタラクティブとなっているので、その部分を関数化します。

今回は簡単にクエリの返答をそのまま返す関数にしました。

def compare(a, b):
    print("? " + a, b)
    return input()

マージソートの実装

データセット1,2はマージソートします。100回の比較で十分間に合います。

 【Python】ソートアルゴリズム - Qiitaからコードを拝借し、比較部分だけ改変して実装しました。

def merge_sort(l):
    if len(l) <= 1:
        return l

    mid = len(l) // 2
    left = l[:mid]
    right = l[mid:]

    left = merge_sort(left)
    right = merge_sort(right)

    return merge(left, right)


def merge(left, right):
    merged = 
    l_i, r_i = 0, 0

    while l_i < len(left) and r_i < len(right):

        print("? "+left[l_i], right[r_i])
        if input() == "<":  # ここだけソースより変更
            merged.append(left[l_i])
            l_i += 1
        else:
            merged.append(right[r_i])
            r_i += 1

    if l_i < len(left):
        merged.extend(left[l_i:])
    if r_i < len(right):
        merged.extend(right[r_i:])
    return merged

データセット3向けの実装

この問題で厄介なのはデータセット3のときです。5個のボールのソートに7回しか使えないので、マージソートでは間に合いません。

よってN=5のときに独自の実装をします。方法は重さの異なるN枚のコイン [楽しいクイズの発信基地!クイズ大陸]を参考にしました。

N=5のときの比較方法

まず2個を選んで比較、残りから2個を選んで比較します。またそのうち軽い方同士を比較し、軽かったほうの組を軽い順に(a, b)、重い組を(c, d)とし、比較してないものをeとします。

ここまででa<bかつa<c<dということがわかります。

次にceを比較します。この結果によって分岐します。

 c<eのとき

次に d,eを比較します。eのほうが軽いとき、d,eを入れ替えます。するとa<bかつa<c<d<eとなります。よってbdと比較→ceと比較、と二分探索風にすることで、7回の比較で順番がすべてわかります。

 e<cのとき

こっちは少し面倒です。まずa,eを比較します。a<eのとき、b,cを比較します。

bが軽ければa<b<c<dかつa<e<c<dとなるので、b,eを比較すればよいです。

またcが軽ければa<e<c<dかつa<e<c<bとなるのでbdを比較すればよいです。

e<aのときはb,cを比較します。

bが軽ければe<a<b<c<dで確定です。この場合だけ最も少ない6回の比較で済みます。

またcが軽ければe<a<c<bかつe<a<c<dとなるのでbdを比較すればよいです。

 実装

上につらつらと書いたことを実装します。elifなどを使わずに愚直に実装するとこんな感じです。

def sort5(l): 
    a, b, c, d, e = l 
    if compare(a, b) == ">": 
        a, b = b, a 
    if compare(c, d) == ">": 
        c, d = d, c 
    if compare(a, c) == ">": 
        a, b, c, d = c, d, a, b 
    if compare(c, e) == "<": 
        if compare(d, e) == ">": 
            d, e = e, d 
        if compare(b, d) == "<": 
            if compare(b, c) == "<": 
                l = [a, b, c, d, e] 
            else: 
                l = [a, c, b, d, e] 
        else: 
            if compare(b, e) == "<": 
                l = [a, c, d, b, e] 
            else: 
                l = [a, c, d, e, b] 
    else: 
        if compare(a, e) == "<": 
            if compare(b, c) == "<": 
                if compare(b, e) == "<": 
                    l = [a, b, e, c, d] 
                else: 
                    l = [a, e, b, c, d] 
            else: 
                if compare(b, d) == "<": 
                    l = [a, e, c, b, d] 
                else: 
                    l = [a, e, c, d, b] 
        else: 
            if compare(b, c) == "<": 
                l = [e, a, b, c, d] 
            else: 
                if compare(b, d) == "<": 
                    l = [e, a, c, b, d] 
                else: 
                    l = [e, a, c, d, b] 
    return l

まとめ

最後に上記のコードをまとめます。文字リストはstring.ascii_uppercaseから取り出し、リスト化して関数に投げた後、返り値を結合して"! "を頭につけてprintします。

最後にflushするのを忘れてなければACするはずです。

from string import ascii_uppercase 


def compare(a, b): 
    print("? " + a, b, flush = True) 
    return input() 


def sort5(l): 
    a, b, c, d, e = l 
    if compare(a, b) == ">": 
        a, b = b, a 
    if compare(c, d) == ">": 
        c, d = d, c 
    if compare(a, c) == ">": 
        a, b, c, d = c, d, a, b 
    if compare(c, e) == "<": 
        if compare(d, e) == ">": 
            d, e = e, d 
        if compare(b, d) == "<": 
            if compare(b, c) == "<": 
                l = [a, b, c, d, e] 
            else: 
                l = [a, c, b, d, e] 
        else: 
            if compare(b, e) == "<": 
                l = [a, c, d, b, e] 
            else: 
                l = [a, c, d, e, b] 
    else: 
        if compare(a, e) == "<": 
            if compare(b, c) == "<": 
                if compare(b, e) == "<": 
                    l = [a, b, e, c, d] 
                else: 
                    l = [a, e, b, c, d] 
            else: 
                if compare(b, d) == "<": 
                    l = [a, e, c, b, d] 
                else: 
                    l = [a, e, c, d, b] 
        else: 
            if compare(b, c) == "<": 
                l = [e, a, b, c, d] 
            else: 
                if compare(b, d) == "<": 
                    l = [e, a, c, b, d] 
                else: 
                    l = [e, a, c, d, b] 
    return l 


def merge_sort(l): 
    if len(l) <= 1: 
        return l 

    mid = len(l) // 2 
    left = l[:mid] 
    right = l[mid:] 

    left = merge_sort(left) 
    right = merge_sort(right) 

    return merge(left, right) 


def merge(left, right): 
    merged = [ ]
    l_i, r_i = 0, 0 

    while l_i < len(left) and r_i < len(right): 

        print("? "+left[l_i], right[r_i]) 
        if input() == "<": 
            merged.append(left[l_i]) 
            l_i += 1 
        else: 
            merged.append(right[r_i]) 
            r_i += 1 

    if l_i < len(left): 
        merged.extend(left[l_i:]) 
    if r_i < len(right): 
        merged.extend(right[r_i:]) 
    return merged 


n, q = map(int, input().split()) 
ans = list(ascii_uppercase[:n]) 
ans = merge_sort(ans) if n != 5 else sort5(ans) 

print("! "+"".join(ans), flush=True) 

 (2020/04/13 "[ ]"が消えるはてなの不具合(?)があったので修正)

 おわり