RCTF 2020 Writeup easy_f(x)

2020年5/30 ~ 6/1 に開催されたRCTF 2020に参加して、なんとか1問(easy_f(x) )解けたのでそのWriteupを書きます.

easy_f(x)

Proof of Work

124.156.140.90 2333につないでみると、次のような応答が返ってくる。

sha256(XXXX+ydl3MyCipSFFKKAl) == 7be4d6e89b8c5d6c33cade48c8bff97bb1670c340536d660f47dc99b835d97df
Give me XXXX:

よく登場するProof of Workですね。この例ではydl3MyCipSFFKKAlを加えた後のsha256ハッシュ値

7be4d6e89b8c5d6c33cade48c8bff97bb1670c340536d660f47dc99b835d97df となる最初の4文字を見つけろというもの。

server.pyを見ると、英数字のみ(string.ascii_letters+string.digits)が使用されているため、それらを用いて総当たりするだけ。最悪の場合でもそこまで時間はかからない。

Calculating Check

ここがかなり大変な部分。。

まず、情報を整理するとKはランダムな数値が513(check+ Pbits)格納されているリスト。そしてここでの目標はこのランダムな数値checkを求めること。

関数fではKの長さである513回r = (r + k[i] * pow (x, i, m)) % mの処理が行われているため、

f(x) = rは次の式で表せられる。

$$ f(x) = check + K_1*x+K_2*x^2 + .. + K_{512}*x^{512} \;mod \;m $$

ここで、未知数は $K_i (i = 1,2,..,512)$ であるため、 checkがなければ、512個の式をもらえれば全てのKの値が求まる。しかし、checkが邪魔となる。そこで式を$512*2 = 1024$個受け取り、互いに引いてcheckを消して$K_1$から$K_{512}$の値を求めれば、

$$ check = f(x) - K_1*x - K_2*x^2 .. \;mod \;m $$

という計算式でcheckの値を求められる。

以上の説明を式にすると次のようになる。

$$ f_1(x_1) - f_{513}(x_{513}) = K_1(x_1- x_{513}) + K_2(x_1^2 - x_{513}^2) + .. + K_{513}(x_1^{512} - x_{513} ^ {512}) \; mod \;m\\ f_2(x_2) - f_{514}(x_{514}) = K_1(x_2- x_{514}) + K_2(x_2^2 - x_{514}^2) + .. + K_{513}(x_2^{512} - x_{514} ^ {512}) \; mod \;m\\ ..\\ ..\\ f_{512}(x_{512}) - f_{1024}(x_{1024}) = K_1(x_{512}- x_{1024}) + K_2(x_{512}^2 - x_{1024}^2) + .. + K_{513}(x_{512}^{512} - x_{1024} ^ {512}) \; mod \;m $$

これによって、未知数 $K_i$ が512個である連立方程式となったため、解くことが可能である。

ただ、自分はpythonでmod上での連立方程式(特に今回のような大規模なもの)を解く方法を知らなったので、ローカルにインストールしていたSageMathを使って解きました。それでも解くのに4分くらいかかったので、ミスがあった時に修正して試すのにかなり時間がかかってしまって大変でした。。

solverは以下。

from sage.all_cmdline import * 
import os,random,sys,string
from hashlib import sha256
from Crypto.Util.number import *
from minipwn import *
from tqdm import tqdm
import time

Pbits = 512

def Get_Hash():
    recv_m = io.recvline().strip()
    print(io.recvuntil("Give me XXXX:"))
    str_4_ = recv_m[recv_m.find(b"+")+1:recv_m.find(b")")].strip()
    sha256_hash = recv_m.split(b"==")[1].strip()
    print("str_4_: {}".format(str_4_))
    print("sha256_hash: {}".format(sha256_hash))
    return str_4_, sha256_hash


def Proof(str_4_, sha256_hash):
    Cand_list = list(string.ascii_letters+string.digits)
    for i in tqdm(range(len(Cand_list))):
        c_1 = Cand_list[i]
        for c_2 in Cand_list:
            for c_3 in Cand_list:
                for c_4 in Cand_list:
                    cand_4 = c_1 + c_2 + c_3 + c_4
                    cand =  cand_4.encode() + str_4_
                    sha_hex = sha256(cand).hexdigest()
                    if sha_hex == sha256_hash:
                        print("Find! cand: {}".format(cand))
                        print("sha_hex: {}".format(sha_hex))
                        print("send: {}".format(cand_4))
                        io.sendline(cand_4)
                        return 
    print("Fail to find..")
    return 0

def Get_Values(times):
    recv_m = io.recvline()
    print("recv_m: {}".format(recv_m))
    M = int(recv_m.split(b"=")[1].strip())
    print("M: {}".format(M))
    print(io.recvuntil("How many f(x) do you want?"))
    io.sendline(str(times))
    print(io.recvline())
    x_list_1 = []
    x_list_2 = []
    r_list_1 = []
    r_list_2 = []
    half_times = times//2
    for i in tqdm(range(times)):
        recv_m = io.recvline().strip()
        x = int(recv_m[recv_m.find(b"(")+1: recv_m.find(b")")].strip())
        r = int(recv_m.split(b"=")[1].strip())
        #print("x: {}".format(x))
        #print("r: {}".format(r))
        if i < half_times:
            x_list_1.append([pow(x, j, M) for j in range(1, Pbits+1)])
            r_list_1.append(r)
        else:
            x_list_2.append([pow(x, j, M) for j in range(1, Pbits+1)])
            r_list_2.append(r)
    return M, x_list_1, x_list_2, r_list_1, r_list_2


def Find_Check():
    R = IntegerModRing(M)
    r_diff_list = [r_list_1[i] - r_list_2[i] for i in range(len(r_list_1))]
    x_diff_list = [[x_list_1[i][j] - x_list_2[i][j] for j in range(len(x_list_1))] for i in range(len(x_list_1[0]))]
    
    r_Vec = vector(R, r_diff_list)
    x_Mat = Matrix(R, x_diff_list)

    print("Now Solving...")
    start = time.time()
    K_list = x_Mat.solve_right(r_Vec)
    end = time.time()
    print("Elapsed: {}".format(end - start))
    #print("K_list: {}".format(K_list))
    print("K_len: {}".format(len(K_list)))

    # check = f(x) - k1*x - k2*x^2 - k2*x^3 ..
    A = r_list_1[0]
    B = sum([(K_list[i]*x_list_1[0][i])%M for i in range(len(x_list_1))])%M
    print("A: {}".format(A))
    print("B: {}".format(B))

    check = (A - B)%M
    print("check: {}".format(check))
    io.sendline(check)


io = remote("124.156.140.90", 2333)

str_4_, sha256_hash = Get_Hash()
Proof(str_4_, sha256_hash.decode())
M, x_list_1, x_list_2, r_list_1, r_list_2 = Get_Values(times = 1024)
Find_Check()

while True:
    print(io.recvline())

Flag: RCTF{A_e4siest_sh4mlr_s3cr3t_sh4rIng!!}