diff --git a/yt_dlp/aes.py b/yt_dlp/aes.py index f52b992df..60cdeb74e 100644 --- a/yt_dlp/aes.py +++ b/yt_dlp/aes.py @@ -178,7 +178,7 @@ def aes_encrypt(data, expanded_key): data = sub_bytes(data) data = shift_rows(data) if i != rounds: - data = mix_columns(data) + data = list(iter_mix_columns(data, MIX_COLUMN_MATRIX)) data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]) return data @@ -197,7 +197,7 @@ def aes_decrypt(data, expanded_key): for i in range(rounds, 0, -1): data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES]) if i != rounds: - data = mix_columns_inv(data) + data = list(iter_mix_columns(data, MIX_COLUMN_MATRIX_INV)) data = shift_rows_inv(data) data = sub_bytes_inv(data) data = xor(data, expanded_key[:BLOCK_SIZE_BYTES]) @@ -375,49 +375,23 @@ def xor(data1, data2): return [x ^ y for x, y in zip(data1, data2)] -def rijndael_mul(a, b): - if a == 0 or b == 0: - return 0 - return RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[a] + RIJNDAEL_LOG_TABLE[b]) % 0xFF] - - -def mix_column(data, matrix): - data_mixed = [] - for row in range(4): - mixed = 0 - for column in range(4): - # xor is (+) and (-) - mixed ^= rijndael_mul(data[column], matrix[row][column]) - data_mixed.append(mixed) - return data_mixed - - -def mix_columns(data, matrix=MIX_COLUMN_MATRIX): - data_mixed = [] - for i in range(4): - column = data[i * 4: (i + 1) * 4] - data_mixed += mix_column(column, matrix) - return data_mixed - - -def mix_columns_inv(data): - return mix_columns(data, MIX_COLUMN_MATRIX_INV) +def iter_mix_columns(data, matrix): + for i in (0, 4, 8, 12): + for row in matrix: + mixed = 0 + for j in range(4): + # xor is (+) and (-) + mixed ^= (0 if data[i:i + 4][j] == 0 or row[j] == 0 else + RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[data[i + j]] + RIJNDAEL_LOG_TABLE[row[j]]) % 0xFF]) + yield mixed def shift_rows(data): - data_shifted = [] - for column in range(4): - for row in range(4): - data_shifted.append(data[((column + row) & 0b11) * 4 + row]) - return data_shifted + return [data[((column + row) & 0b11) * 4 + row] for column in range(4) for row in range(4)] def shift_rows_inv(data): - data_shifted = [] - for column in range(4): - for row in range(4): - data_shifted.append(data[((column - row) & 0b11) * 4 + row]) - return data_shifted + return [data[((column - row) & 0b11) * 4 + row] for column in range(4) for row in range(4)] def shift_block(data):