【 Effective Python, 2nd Edition 】ブロッキング I/O ( blocking I/O ) とスレッド ( thread ) を利用しているプログラムを、asyncio 組み込みモジュールを利用してコルーチン ( coroutine ) と非同期 I/O ( asyncronous I/O ) を利用したプログラムにリファクタリング ( refactoring ) しよう! 投稿一覧へ戻る
Published 2020年9月27日22:00 by mootaro23
SUPPORT UKRAINE
- Your indifference to the act of cruelty can thrive rogue nations like Russia -
さて、コルーチンを利用した fan-out、fan-in パターン実装を取り上げた記事 を読んで、「でも既存のプログラムをコルーチンを利用した非同期プログラムに移行するのは大変だ」という感想を持った方は多いようです。
Python には言語自体に非同期実行 ( asynchronous execution ) プログラムを記述するための仕様がスマートに組み込まれています。
それによって、スレッド ( thread ) を利用してブロッキング I/O ( blocking I/O ) を処理しているコードを、スムーズにコルーチン ( coroutine ) と asyncronous I/O を利用した実装にリファクタリング ( refactoring ) することが出来るようになっているんです。
今回は、そういったリファクタリングの流れを、TCP ベースでサーバーとクライアント間でデータをやり取りしながら行う数当てゲーム ( guessing a number game ) を題材に見ていきたいと思います。
このゲームは、クライアントがサーバーに対して数の範囲を提示し、その範囲内で「答え」となる数をサーバーが当てる、というものです。
ゲームの流れは、
となっていて、1回のセッション(ゲーム)の中で指定した回数だけ 2: - 5: が繰り返されます (正解が出ればその時点で終了)。
まぁ、このサーバーはクライアントからリクエストがある都度、提示された数の範囲内でランダムに数値を生成しそれを返すだけですので guessing game とは言い過ぎですが、あくまでも目的は「非同期プログラムへのリファクタリング」ですのでご了承ください。
さて、このようなクライアント/サーバー ( client / server ) システムを構築する最も一般的な方法は、ブロッキング I/O ( blocking I/O ) とスレッド ( thread ) を利用することでしょう。
まずは、クライアントとサーバー間でメッセージをやり取りするための機能を提供するヘルパークラスを定義しましょう。
やり取りするメッセージには処理内容を指示するコマンド文字列が含まれています。
続けて、サーバーを 1 つのクラスとして定義します。
クライアントも 1 つのクラスとして実装します。
このサーバーでは、クライアントからの接続要求に応える処理関数 ( handle_connection() ) をワーカースレッドで実行することで、複数のクライアントに応えることが可能になっています。
クライアントはメインスレッドで動作します。
さぁ、準備は出来ました。
サーバーを稼動し、クライアントを立ち上げて実行してみましょう!
期待通りに実行されていますね!
でも安心している場合じゃぁありません、ここからが今回の記事の本題です。
ここまでの例はブロッキング I/O とスレッドを利用した「一般的」な実装です。
果たしてこの「既存」のコードを async、await、asyncio を利用して、コルーチン ( coroutine ) と asyncronous I/O による非同期実行プログラムにリファクタリングするにはどれ位の手間がかかってしまうのでしょうか?
ではやっていきましょう!
どの部分を変更したのかが判断しやすいように、元のコードに手を加えた部分には「追加」、「変更」等のコメントを入れておきます。
まずは ConnectionBase クラスで提供していた I/O 処理メソッド ( send() と receive() ) をコルーチンに書き換え、ブロッキング I/O メソッドの代わりに、接続ソケットに対して非同期読み込み、書き込みを実行可能な StreamReader クラスインスタンス、StreamWriter クラスインスタンスを利用するようにします。
クラス名もそれっぽく AsyncConnectionBase と変更しましょう。
サーバー側クラスが AsyncConnectionBase クラスから派生するようにします。
そして、実際に I/O 処理に携わっているクラスメソッドを async def キーワードでコルーチンに変更し、send()、receive() を呼び出している箇所を await 式に変更しましょう。
実はサーバー部分の変更はこれだけなんです!
クライアント側クラスも AsyncConnectionBase クラスから派生するようにします。
続けて先程のサーバークラスと同様に、実際に I/O 処理に携わっているクラスメソッドを async def キーワードでコルーチンに変更し、send()、receive() を呼び出している箇所を await 式に変更しましょう。
また、この変更に伴って、contextlib.contextmanager() を利用してデコレートしている session() は、contextlib.asynccontextmanager() でデコレートするように変更します。
サーバーを開始する関数は asyncio 組み込みモジュールの start_server() を利用して完全に書き換える必要があります。
クライアントを実行する関数も大幅に変更する必要があります。
これは、ブロッキング socket インスタンスを利用していた箇所を対応する asyncio バージョンのものにしなければなりませんし、コルーチンを呼び出す箇所には全て await キーワードをつける必要がありますし、for や with も async バージョンに変更する必要がありますから仕方ありませんね。
また、asyncio.open_connection() を with 文で利用することが出来ないため、クライアントの処理が終了した時点で StreamWriter インスタンスを閉じることでサーバーに接続終了を通知しています。
最後に、このプログラムのエントリーポイントである関数に変更を加えましょう。
この関数もコルーチンに変更し、サーバー開始コルーチンを asyncio.create_task() を利用してイベントループで実行されるようにスケジューリングします。
これによって、サーバー側、クライアント側とも、実行が await 式に到達するたびに動作が切り替えられ、効率的な並列処理が実行されるようになります。
こちらも期待通りに動作しています。
この async バージョンは複数のスレッドにまたがって実行されることもないので、debugger を利用してコードを追いかけることも容易です。
いかがだったでしょうか?
思っていたほど大変ではなかったのではないでしょうか?
今回利用していない機能は沢山ありますし、asyncio 関連は Python においてもまだまだ充実が図られている機能の一つです。
是非お互いがんばって勉強していきましょう!!
まとめ:
Python には言語自体に非同期実行 ( asynchronous execution ) プログラムを記述するための仕様がスマートに組み込まれています。
それによって、スレッド ( thread ) を利用してブロッキング I/O ( blocking I/O ) を処理しているコードを、スムーズにコルーチン ( coroutine ) と asyncronous I/O を利用した実装にリファクタリング ( refactoring ) することが出来るようになっているんです。
今回は、そういったリファクタリングの流れを、TCP ベースでサーバーとクライアント間でデータをやり取りしながら行う数当てゲーム ( guessing a number game ) を題材に見ていきたいと思います。
このゲームは、クライアントがサーバーに対して数の範囲を提示し、その範囲内で「答え」となる数をサーバーが当てる、というものです。
ゲームの流れは、
1: クライアントが出題
2: クライアントが「数 (当ててネ)」を要求
3: サーバーが出題された範囲内でランダムに数を生成して解答
4: クライアントが結果をサーバーに通知
5: サーバがレポートを出力 (クライアントが「隠し持っている正解」に前回の推測から近付いたのであれば 'WARMER'、遠ざかったのであれば 'COLDER')
となっていて、1回のセッション(ゲーム)の中で指定した回数だけ 2: - 5: が繰り返されます (正解が出ればその時点で終了)。
まぁ、このサーバーはクライアントからリクエストがある都度、提示された数の範囲内でランダムに数値を生成しそれを返すだけですので guessing game とは言い過ぎですが、あくまでも目的は「非同期プログラムへのリファクタリング」ですのでご了承ください。
さて、このようなクライアント/サーバー ( client / server ) システムを構築する最も一般的な方法は、ブロッキング I/O ( blocking I/O ) とスレッド ( thread ) を利用することでしょう。
まずは、クライアントとサーバー間でメッセージをやり取りするための機能を提供するヘルパークラスを定義しましょう。
やり取りするメッセージには処理内容を指示するコマンド文字列が含まれています。
class EOFError(Exception):
"""
接続が切断されたことを知らせる例外
"""
pass
class ConnectionBase:
"""
ソケットを利用したデータ送受信処理を実装する基底クラス
"""
def __init__(self, connection):
"""
:param connection: データ送受信の対象となるソケット
:attr file: 受信時にデータを読むためのソケットに紐付けたファイルオブジェクト
"""
self.connection = connection
self.file = connection.makefile('rb')
def send(self, command: str):
"""
データ送信
:param command: 送信データ
"""
line = command + '\n'
data = line.encode()
self.connection.send(data)
def receive(self):
"""
データ受信
クライアント側の接続ソケットが閉じられた場合に EOFError を発生させ、プログラムを終了させます。
接続ソケットが閉じられるのは、run_client() 内の create_connection() コンテキストブロックが終了した時です。
:return: 最後尾に付加されている改行文字を取り除いた受信文字列。
"""
line = self.file.readline()
if not line:
raise EOFError('== 接続が切断されました ==')
return line.rstrip().decode()
"""
接続が切断されたことを知らせる例外
"""
pass
class ConnectionBase:
"""
ソケットを利用したデータ送受信処理を実装する基底クラス
"""
def __init__(self, connection):
"""
:param connection: データ送受信の対象となるソケット
:attr file: 受信時にデータを読むためのソケットに紐付けたファイルオブジェクト
"""
self.connection = connection
self.file = connection.makefile('rb')
def send(self, command: str):
"""
データ送信
:param command: 送信データ
"""
line = command + '\n'
data = line.encode()
self.connection.send(data)
def receive(self):
"""
データ受信
クライアント側の接続ソケットが閉じられた場合に EOFError を発生させ、プログラムを終了させます。
接続ソケットが閉じられるのは、run_client() 内の create_connection() コンテキストブロックが終了した時です。
:return: 最後尾に付加されている改行文字を取り除いた受信文字列。
"""
line = self.file.readline()
if not line:
raise EOFError('== 接続が切断されました ==')
return line.rstrip().decode()
続けて、サーバーを 1 つのクラスとして定義します。
import random
# サーバーレポートに使用する「採点」基準。前回の推測と比較してどうだったか、を示す
WARMER = '近付いた'
COLDER = '遠ざかった'
UNSURE = '不明'
CORRECT = '正解!'
SAME = '前回と同じ'
class UnknownCommandError(Exception):
"""
意図しないコマンドを取得した場合に発生させる例外。
この例では、クライアント、サーバーとも人の手が介在しないため「ハッキング」されない限り発生することはない
"""
pass
class Session(ConnectionBase):
"""
サーバー側データ送受信処理実装クラス
"""
def __init__(self, *args):
super().__init__(*args)
self._clear_state(None, None)
def _clear_state(self, lower, upper):
self.lower = lower
self.upper = upper
self.guesses = []
def loop(self):
"""
クライアントからの接続要求待ち受けループ
"""
while command := self.receive():
parts = command.split(' ')
if parts[0] == 'PARAMS':
self.set_params(parts)
elif parts[0] == 'NUMBER':
self.send_number()
elif parts[0] == 'REPORT':
self.receive_report(parts)
else:
raise UnknownCommandError(command)
def set_params(self, parts):
"""
PARAMS コマンド処理関数
新たな数当てゲーム開始処理
:param parts: 'PARAM' '数当て範囲の最小値' '数当て範囲の最大値' の 3 要素からなるリスト
"""
assert len(parts) == 3
lower = int(parts[1])
upper = int(parts[2])
self._clear_state(lower, upper)
def next_guess(self):
"""
数当て範囲の中でランダム数 (推測した数) を生成
複数回の解答において同じ数は返さないので、数の範囲より回答数が等しいか大きければ必ず正答を出力できます。
:return: 推測数 ( guess number )
"""
while True:
guess = random.randint(self.lower, self.upper)
if guess not in self.guesses:
return guess
def send_number(self):
"""
NUMBER コマンド処理関数
推測数 ( guess number ) 生成、送信
"""
guess = self.next_guess()
self.guesses.append(guess)
self.send(format(guess))
def receive_report(self, parts):
"""
REPORT コマンド処理関数
解答受信、レポート出力処理
:param parts: 'REPORT' '「採点」を示す定数' の 2 要素からなるリスト
"""
assert len(parts) == 2
decision = parts[1]
last = self.guesses[-1]
print(f"Server: {last} is {decision}")
# サーバーレポートに使用する「採点」基準。前回の推測と比較してどうだったか、を示す
WARMER = '近付いた'
COLDER = '遠ざかった'
UNSURE = '不明'
CORRECT = '正解!'
SAME = '前回と同じ'
class UnknownCommandError(Exception):
"""
意図しないコマンドを取得した場合に発生させる例外。
この例では、クライアント、サーバーとも人の手が介在しないため「ハッキング」されない限り発生することはない
"""
pass
class Session(ConnectionBase):
"""
サーバー側データ送受信処理実装クラス
"""
def __init__(self, *args):
super().__init__(*args)
self._clear_state(None, None)
def _clear_state(self, lower, upper):
self.lower = lower
self.upper = upper
self.guesses = []
def loop(self):
"""
クライアントからの接続要求待ち受けループ
"""
while command := self.receive():
parts = command.split(' ')
if parts[0] == 'PARAMS':
self.set_params(parts)
elif parts[0] == 'NUMBER':
self.send_number()
elif parts[0] == 'REPORT':
self.receive_report(parts)
else:
raise UnknownCommandError(command)
def set_params(self, parts):
"""
PARAMS コマンド処理関数
新たな数当てゲーム開始処理
:param parts: 'PARAM' '数当て範囲の最小値' '数当て範囲の最大値' の 3 要素からなるリスト
"""
assert len(parts) == 3
lower = int(parts[1])
upper = int(parts[2])
self._clear_state(lower, upper)
def next_guess(self):
"""
数当て範囲の中でランダム数 (推測した数) を生成
複数回の解答において同じ数は返さないので、数の範囲より回答数が等しいか大きければ必ず正答を出力できます。
:return: 推測数 ( guess number )
"""
while True:
guess = random.randint(self.lower, self.upper)
if guess not in self.guesses:
return guess
def send_number(self):
"""
NUMBER コマンド処理関数
推測数 ( guess number ) 生成、送信
"""
guess = self.next_guess()
self.guesses.append(guess)
self.send(format(guess))
def receive_report(self, parts):
"""
REPORT コマンド処理関数
解答受信、レポート出力処理
:param parts: 'REPORT' '「採点」を示す定数' の 2 要素からなるリスト
"""
assert len(parts) == 2
decision = parts[1]
last = self.guesses[-1]
print(f"Server: {last} is {decision}")
クライアントも 1 つのクラスとして実装します。
import contextlib
import math
import time
class Client(ConnectionBase):
"""
クライアント側処理実装クラス
"""
def __init__(self, *args):
super().__init__(*args)
self._clear_state()
def _clear_state(self):
self.secret = None
self.last_distance = None
@contextlib.contextmanager
def session(self, lower, upper, secret):
"""
新たな数当てゲームセッションを開始します。
コンテキストマネージャとして with 文で利用されますが、yield 文では何も返していないので、もし with 文の as 節で受けても None が入るだけです。
yield 文までが __enter__() の実体として実行された後、with ブロックが実行され、終了時に __exit__() の実体として yield 文以降が実行されます。
:param lower: 数当て範囲の最小値
:param upper: 数当て範囲の最大値
:param secret: 正解
"""
print(f"{lower} と {upper} の間で数当てゲーム! -- 秘密だョ、正解は {secret} !")
self.secret = secret
self.send(f"PARAMS {lower} {upper}")
try:
yield
finally:
self._clear_state()
def request_numbers(self, count):
"""
推測数 ( guess number ) 要求ジェネレータ関数
:param count: 許可する推測回数
:return: サーバーが返答してきた推測数 ( guess number )
"""
for _ in range(count):
self.send('NUMBER')
data = self.receive()
yield int(data)
if self.last_distance == 0:
return
def report_outcome(self, number):
"""
サーバーが返答してきた推測数 ( guess number ) に対する「採点」返信処理
「採点 ( WARMER、COLDER 等)」は前回の返答に対しての相対比較であり、「解答」に対する絶対比較ではありません。
:param number: サーバーが返答してきた推測数 ( guess number )
:return: 採点 ('UNSURE', 'CORRECT', 'WARMER', 'COLDER', 'SAME' のいずれか)
"""
new_distance = math.fabs(number - self.secret)
decision = UNSURE
if new_distance == 0:
decision = CORRECT
elif self.last_distance is None:
pass
elif new_distance < self.last_distance:
decision = WARMER
elif new_distance > self.last_distance:
decision = COLDER
elif new_distance == self.last_distance:
decision = SAME
self.last_distance = new_distance
self.send(f"REPORT {decision}")
return decision
import math
import time
class Client(ConnectionBase):
"""
クライアント側処理実装クラス
"""
def __init__(self, *args):
super().__init__(*args)
self._clear_state()
def _clear_state(self):
self.secret = None
self.last_distance = None
@contextlib.contextmanager
def session(self, lower, upper, secret):
"""
新たな数当てゲームセッションを開始します。
コンテキストマネージャとして with 文で利用されますが、yield 文では何も返していないので、もし with 文の as 節で受けても None が入るだけです。
yield 文までが __enter__() の実体として実行された後、with ブロックが実行され、終了時に __exit__() の実体として yield 文以降が実行されます。
:param lower: 数当て範囲の最小値
:param upper: 数当て範囲の最大値
:param secret: 正解
"""
print(f"{lower} と {upper} の間で数当てゲーム! -- 秘密だョ、正解は {secret} !")
self.secret = secret
self.send(f"PARAMS {lower} {upper}")
try:
yield
finally:
self._clear_state()
def request_numbers(self, count):
"""
推測数 ( guess number ) 要求ジェネレータ関数
:param count: 許可する推測回数
:return: サーバーが返答してきた推測数 ( guess number )
"""
for _ in range(count):
self.send('NUMBER')
data = self.receive()
yield int(data)
if self.last_distance == 0:
return
def report_outcome(self, number):
"""
サーバーが返答してきた推測数 ( guess number ) に対する「採点」返信処理
「採点 ( WARMER、COLDER 等)」は前回の返答に対しての相対比較であり、「解答」に対する絶対比較ではありません。
:param number: サーバーが返答してきた推測数 ( guess number )
:return: 採点 ('UNSURE', 'CORRECT', 'WARMER', 'COLDER', 'SAME' のいずれか)
"""
new_distance = math.fabs(number - self.secret)
decision = UNSURE
if new_distance == 0:
decision = CORRECT
elif self.last_distance is None:
pass
elif new_distance < self.last_distance:
decision = WARMER
elif new_distance > self.last_distance:
decision = COLDER
elif new_distance == self.last_distance:
decision = SAME
self.last_distance = new_distance
self.send(f"REPORT {decision}")
return decision
このサーバーでは、クライアントからの接続要求に応える処理関数 ( handle_connection() ) をワーカースレッドで実行することで、複数のクライアントに応えることが可能になっています。
import socket
from threading import Thread
def handle_connection(connection):
"""
サーバー側データ送受信処理クラスを別スレッドで実行するための仲介関数
:param connection: 特定のアドレスに bind され、linsten(), accept() によって接続受け付け可能になっている接続ソケット
"""
with connection:
session = Session(connection)
try:
session.loop()
except EOFError as e:
print(e)
def run_server(address):
"""
サーバー側接続待ち受け処理
:param address: (ホスト, ポート) から成るタプル
"""
with socket.socket() as listener:
listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # 同じポートを再利用できるようにオプションを設定します
listener.bind(address)
listener.listen()
while True:
connection, _ = listener.accept()
thread = Thread(target=handle_connection, args=(connection,), daemon=True) # 実際のデータ送受信処理は別スレッドで行います
thread.start()
# Python 3.8 からは create_server() でサーバー側のソケットオブジェクトを取得できるようになりました
#
# with socket.create_server(address) as listener:
# while True:
# connection, _ = listener.accept()
# thread = Thread(target=handle_connection, args=(connection,), daemon=True)
# thread.start()
from threading import Thread
def handle_connection(connection):
"""
サーバー側データ送受信処理クラスを別スレッドで実行するための仲介関数
:param connection: 特定のアドレスに bind され、linsten(), accept() によって接続受け付け可能になっている接続ソケット
"""
with connection:
session = Session(connection)
try:
session.loop()
except EOFError as e:
print(e)
def run_server(address):
"""
サーバー側接続待ち受け処理
:param address: (ホスト, ポート) から成るタプル
"""
with socket.socket() as listener:
listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # 同じポートを再利用できるようにオプションを設定します
listener.bind(address)
listener.listen()
while True:
connection, _ = listener.accept()
thread = Thread(target=handle_connection, args=(connection,), daemon=True) # 実際のデータ送受信処理は別スレッドで行います
thread.start()
# Python 3.8 からは create_server() でサーバー側のソケットオブジェクトを取得できるようになりました
#
# with socket.create_server(address) as listener:
# while True:
# connection, _ = listener.accept()
# thread = Thread(target=handle_connection, args=(connection,), daemon=True)
# thread.start()
クライアントはメインスレッドで動作します。
def run_client(address):
"""
サーバーが listen しているサービスに接続するクライアント
:param address: (ホスト, ポート) から成るタプル
"""
with socket.create_connection(address) as connection:
client = Client(connection)
with client.session(1, 5, 3):
results = [(x, client.report_outcome(x)) for x in client.request_numbers(5)]
time.sleep(0.05)
with client.session(10, 15, 12):
for number in client.request_numbers(6):
outcome = client.report_outcome(number)
results.append((number, outcome))
time.sleep(0.05)
return results
"""
サーバーが listen しているサービスに接続するクライアント
:param address: (ホスト, ポート) から成るタプル
"""
with socket.create_connection(address) as connection:
client = Client(connection)
with client.session(1, 5, 3):
results = [(x, client.report_outcome(x)) for x in client.request_numbers(5)]
time.sleep(0.05)
with client.session(10, 15, 12):
for number in client.request_numbers(6):
outcome = client.report_outcome(number)
results.append((number, outcome))
time.sleep(0.05)
return results
さぁ、準備は出来ました。
サーバーを稼動し、クライアントを立ち上げて実行してみましょう!
def main():
address = ('127.0.0.1', 1234) # この例題では、同じマシン内にサーバーとクライアントを作成しやり取りする
server_thread = Thread(target=run_server, args=(address,), daemon=True)
server_thread.start() # サーバーを別スレッドで実行
results = run_client(address)
print()
for number, outcome in results:
print(f"Client: {number} is {outcome}")
main()
# 1 と 5 の間で数当てゲーム! -- 秘密だョ、正解は 3 !
# Server: 2 is 不明
# Server: 3 is 正解!
# 10 と 15 の間で数当てゲーム! -- 秘密だョ、正解は 12 !
# Server: 14 is 不明
# Server: 15 is 遠ざかった
# Server: 12 is 正解!
# == 接続が切断されました ==
#
# Client: 2 is 不明
# Client: 3 is 正解!
# Client: 14 is 不明
# Client: 15 is 遠ざかった
# Client: 12 is 正解!
address = ('127.0.0.1', 1234) # この例題では、同じマシン内にサーバーとクライアントを作成しやり取りする
server_thread = Thread(target=run_server, args=(address,), daemon=True)
server_thread.start() # サーバーを別スレッドで実行
results = run_client(address)
print()
for number, outcome in results:
print(f"Client: {number} is {outcome}")
main()
# 1 と 5 の間で数当てゲーム! -- 秘密だョ、正解は 3 !
# Server: 2 is 不明
# Server: 3 is 正解!
# 10 と 15 の間で数当てゲーム! -- 秘密だョ、正解は 12 !
# Server: 14 is 不明
# Server: 15 is 遠ざかった
# Server: 12 is 正解!
# == 接続が切断されました ==
#
# Client: 2 is 不明
# Client: 3 is 正解!
# Client: 14 is 不明
# Client: 15 is 遠ざかった
# Client: 12 is 正解!
期待通りに実行されていますね!
でも安心している場合じゃぁありません、ここからが今回の記事の本題です。
ここまでの例はブロッキング I/O とスレッドを利用した「一般的」な実装です。
果たしてこの「既存」のコードを async、await、asyncio を利用して、コルーチン ( coroutine ) と asyncronous I/O による非同期実行プログラムにリファクタリングするにはどれ位の手間がかかってしまうのでしょうか?
ではやっていきましょう!
どの部分を変更したのかが判断しやすいように、元のコードに手を加えた部分には「追加」、「変更」等のコメントを入れておきます。
まずは ConnectionBase クラスで提供していた I/O 処理メソッド ( send() と receive() ) をコルーチンに書き換え、ブロッキング I/O メソッドの代わりに、接続ソケットに対して非同期読み込み、書き込みを実行可能な StreamReader クラスインスタンス、StreamWriter クラスインスタンスを利用するようにします。
クラス名もそれっぽく AsyncConnectionBase と変更しましょう。
import asyncio # 追加
class EOFError(Exception):
pass
class AsyncConnectionBase:
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): # 変更
self.reader = reader # 変更
self.writer = writer # 変更
async def send(self, command: str):
"""
StreamWriter クラスインスタンスによる書き込み時には、write() と drain() をセットで利用するように習慣付けましょう。
writer.write(data)
await writer.drain()
で 1 セットです。
"""
line = command + '\n'
data = line.encode()
self.writer.write(data) # 変更
await self.writer.drain() # 変更
async def receive(self):
line = await self.reader.readline() # 変更
if not line:
self.writer.close() # 追加
raise EOFError('== 接続が切断されました ==')
return line.rstrip().decode()
import random
WARMER = '近付いた'
COLDER = '遠ざかった'
UNSURE = '不明'
CORRECT = '正解!'
SAME = '前回と同じ'
class UnknownCommandError(Exception):
pass
class EOFError(Exception):
pass
class AsyncConnectionBase:
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): # 変更
self.reader = reader # 変更
self.writer = writer # 変更
async def send(self, command: str):
"""
StreamWriter クラスインスタンスによる書き込み時には、write() と drain() をセットで利用するように習慣付けましょう。
writer.write(data)
await writer.drain()
で 1 セットです。
"""
line = command + '\n'
data = line.encode()
self.writer.write(data) # 変更
await self.writer.drain() # 変更
async def receive(self):
line = await self.reader.readline() # 変更
if not line:
self.writer.close() # 追加
raise EOFError('== 接続が切断されました ==')
return line.rstrip().decode()
import random
WARMER = '近付いた'
COLDER = '遠ざかった'
UNSURE = '不明'
CORRECT = '正解!'
SAME = '前回と同じ'
class UnknownCommandError(Exception):
pass
サーバー側クラスが AsyncConnectionBase クラスから派生するようにします。
そして、実際に I/O 処理に携わっているクラスメソッドを async def キーワードでコルーチンに変更し、send()、receive() を呼び出している箇所を await 式に変更しましょう。
実はサーバー部分の変更はこれだけなんです!
class AsyncSession(AsyncConnectionBase): # 変更
def __init__(self, *args):
super().__init__(*args)
self._clear_state(None, None)
def _clear_state(self, lower, upper):
self.lower = lower
self.upper = upper
self.guesses = []
async def loop(self): # 変更
while command := await self.receive(): # 変更
parts = command.split(' ')
if parts[0] == 'PARAMS':
self.set_params(parts)
elif parts[0] == 'NUMBER':
await self.send_number() # 変更
elif parts[0] == 'REPORT':
self.receive_report(parts)
else:
raise UnknownCommandError(command)
def set_params(self, parts):
assert len(parts) == 3
lower = int(parts[1])
upper = int(parts[2])
self._clear_state(lower, upper)
def next_guess(self):
while True:
guess = random.randint(self.lower, self.upper)
if guess not in self.guesses:
return guess
async def send_number(self): # 変更
guess = self.next_guess()
self.guesses.append(guess)
await self.send(format(guess)) # 変更
def receive_report(self, parts):
assert len(parts) == 2
decision = parts[1]
last = self.guesses[-1]
print(f"Server: {last} is {decision}")
def __init__(self, *args):
super().__init__(*args)
self._clear_state(None, None)
def _clear_state(self, lower, upper):
self.lower = lower
self.upper = upper
self.guesses = []
async def loop(self): # 変更
while command := await self.receive(): # 変更
parts = command.split(' ')
if parts[0] == 'PARAMS':
self.set_params(parts)
elif parts[0] == 'NUMBER':
await self.send_number() # 変更
elif parts[0] == 'REPORT':
self.receive_report(parts)
else:
raise UnknownCommandError(command)
def set_params(self, parts):
assert len(parts) == 3
lower = int(parts[1])
upper = int(parts[2])
self._clear_state(lower, upper)
def next_guess(self):
while True:
guess = random.randint(self.lower, self.upper)
if guess not in self.guesses:
return guess
async def send_number(self): # 変更
guess = self.next_guess()
self.guesses.append(guess)
await self.send(format(guess)) # 変更
def receive_report(self, parts):
assert len(parts) == 2
decision = parts[1]
last = self.guesses[-1]
print(f"Server: {last} is {decision}")
クライアント側クラスも AsyncConnectionBase クラスから派生するようにします。
続けて先程のサーバークラスと同様に、実際に I/O 処理に携わっているクラスメソッドを async def キーワードでコルーチンに変更し、send()、receive() を呼び出している箇所を await 式に変更しましょう。
また、この変更に伴って、contextlib.contextmanager() を利用してデコレートしている session() は、contextlib.asynccontextmanager() でデコレートするように変更します。
import contextlib
import math
class AsyncClient(AsyncConnectionBase): # 変更
def __init__(self, *args):
super().__init__(*args)
self._clear_state()
def _clear_state(self):
self.secret = None
self.last_distance = None
@contextlib.asynccontextmanager # 変更
async def session(self, lower, upper, secret): # 変更
print(f"{lower} と {upper} の間で数当てゲーム! -- 秘密だョ、正解は {secret} !")
self.secret = secret
await self.send(f"PARAMS {lower} {upper}") # 変更
try:
yield # この関数を非同期ジェネレータ ( asynchronous generator ) として定義します
finally:
self._clear_state()
async def request_numbers(self, count): # 変更
for _ in range(count):
await self.send('NUMBER') # 変更
data = await self.receive() # 変更
yield int(data)
if self.last_distance == 0:
return
async def report_outcome(self, number): # 変更
new_distance = math.fabs(number - self.secret)
decision = UNSURE
if new_distance == 0:
decision = CORRECT
elif self.last_distance is None:
pass
elif new_distance < self.last_distance:
decision = WARMER
elif new_distance > self.last_distance:
decision = COLDER
elif new_distance == self.last_distance:
decision = SAME
self.last_distance = new_distance
await self.send(f"REPORT {decision}") # 変更
await asyncio.sleep(0.01) # 追加: スレッドを使用した実装と出力順番が同じになるようにします
return decision
import math
class AsyncClient(AsyncConnectionBase): # 変更
def __init__(self, *args):
super().__init__(*args)
self._clear_state()
def _clear_state(self):
self.secret = None
self.last_distance = None
@contextlib.asynccontextmanager # 変更
async def session(self, lower, upper, secret): # 変更
print(f"{lower} と {upper} の間で数当てゲーム! -- 秘密だョ、正解は {secret} !")
self.secret = secret
await self.send(f"PARAMS {lower} {upper}") # 変更
try:
yield # この関数を非同期ジェネレータ ( asynchronous generator ) として定義します
finally:
self._clear_state()
async def request_numbers(self, count): # 変更
for _ in range(count):
await self.send('NUMBER') # 変更
data = await self.receive() # 変更
yield int(data)
if self.last_distance == 0:
return
async def report_outcome(self, number): # 変更
new_distance = math.fabs(number - self.secret)
decision = UNSURE
if new_distance == 0:
decision = CORRECT
elif self.last_distance is None:
pass
elif new_distance < self.last_distance:
decision = WARMER
elif new_distance > self.last_distance:
decision = COLDER
elif new_distance == self.last_distance:
decision = SAME
self.last_distance = new_distance
await self.send(f"REPORT {decision}") # 変更
await asyncio.sleep(0.01) # 追加: スレッドを使用した実装と出力順番が同じになるようにします
return decision
サーバーを開始する関数は asyncio 組み込みモジュールの start_server() を利用して完全に書き換える必要があります。
import socket # 不要になりました
from threading import Thread # 不要になりました
async def handle_async_connection(reader, writer): # 新規
"""
クライアントから接続要求があるたびに呼び出され、データの送受信を開始します
:param reader: クライアントからのデータを受信する StreamReader クラスインスタンス
:param writer: クライアントへデータを送信する StreamWriter クラスインスタンス
"""
session = AsyncSession(reader, writer)
try:
await session.loop()
except EOFError as e:
print(e)
async def run_async_server(address): # 新規
"""
サーバーを開始します
:param address: (ホスト名、ポート番号) からなるタプル
"""
async with (server := await asyncio.start_server(handle_async_connection, *address)):
await server.serve_forever()
from threading import Thread # 不要になりました
async def handle_async_connection(reader, writer): # 新規
"""
クライアントから接続要求があるたびに呼び出され、データの送受信を開始します
:param reader: クライアントからのデータを受信する StreamReader クラスインスタンス
:param writer: クライアントへデータを送信する StreamWriter クラスインスタンス
"""
session = AsyncSession(reader, writer)
try:
await session.loop()
except EOFError as e:
print(e)
async def run_async_server(address): # 新規
"""
サーバーを開始します
:param address: (ホスト名、ポート番号) からなるタプル
"""
async with (server := await asyncio.start_server(handle_async_connection, *address)):
await server.serve_forever()
クライアントを実行する関数も大幅に変更する必要があります。
これは、ブロッキング socket インスタンスを利用していた箇所を対応する asyncio バージョンのものにしなければなりませんし、コルーチンを呼び出す箇所には全て await キーワードをつける必要がありますし、for や with も async バージョンに変更する必要がありますから仕方ありませんね。
また、asyncio.open_connection() を with 文で利用することが出来ないため、クライアントの処理が終了した時点で StreamWriter インスタンスを閉じることでサーバーに接続終了を通知しています。
async def run_async_client(address): # 変更
await asyncio.sleep(0.1) # サーバーの準備ができるのをちょっと待ちます
streams = await asyncio.open_connection(*address) # 変更
client = AsyncClient(*streams) # 変更
async with client.session(1, 5, 3):
results = [(x, await client.report_outcome(x)) async for x in client.request_numbers(5)] # 変更
await asyncio.sleep(0.05)
async with client.session(10, 15, 12):
async for number in client.request_numbers(6): # 変更
outcome = await client.report_outcome(number) # 変更
results.append((number, outcome))
await asyncio.sleep(0.05)
_, writer = streams # 新規
writer.close() # 新規
await writer.wait_closed() # 新規
return results
await asyncio.sleep(0.1) # サーバーの準備ができるのをちょっと待ちます
streams = await asyncio.open_connection(*address) # 変更
client = AsyncClient(*streams) # 変更
async with client.session(1, 5, 3):
results = [(x, await client.report_outcome(x)) async for x in client.request_numbers(5)] # 変更
await asyncio.sleep(0.05)
async with client.session(10, 15, 12):
async for number in client.request_numbers(6): # 変更
outcome = await client.report_outcome(number) # 変更
results.append((number, outcome))
await asyncio.sleep(0.05)
_, writer = streams # 新規
writer.close() # 新規
await writer.wait_closed() # 新規
return results
最後に、このプログラムのエントリーポイントである関数に変更を加えましょう。
この関数もコルーチンに変更し、サーバー開始コルーチンを asyncio.create_task() を利用してイベントループで実行されるようにスケジューリングします。
これによって、サーバー側、クライアント側とも、実行が await 式に到達するたびに動作が切り替えられ、効率的な並列処理が実行されるようになります。
async def main_async(): # 変更
address = ('127.0.0.1', 4321)
server = run_async_server(address) # 変更
asyncio.create_task(server) # 変更
results = await run_async_client(address) # 変更
print()
for number, outcome in results:
print(f"Client: {number} is {outcome}")
asyncio.run(main_async()) # 変更
# 1 と 5 の間で数当てゲーム! -- 秘密だョ、正解は 3 !
# Server: 1 is 不明
# Server: 2 is 近付いた
# Server: 5 is 遠ざかった
# Server: 4 is 近付いた
# Server: 3 is 正解!
# 10 と 15 の間で数当てゲーム! -- 秘密だョ、正解は 12 !
# Server: 14 is 不明
# Server: 10 is 前回と同じ
# Server: 12 is 正解!
#
# Client: 1 is 不明
# Client: 2 is 近付いた
# Client: 5 is 遠ざかった
# Client: 4 is 近付いた
# Client: 3 is 正解!
# Client: 14 is 不明
# Client: 10 is 前回と同じ
# Client: 12 is 正解!
# == 接続が切断されました ==
address = ('127.0.0.1', 4321)
server = run_async_server(address) # 変更
asyncio.create_task(server) # 変更
results = await run_async_client(address) # 変更
print()
for number, outcome in results:
print(f"Client: {number} is {outcome}")
asyncio.run(main_async()) # 変更
# 1 と 5 の間で数当てゲーム! -- 秘密だョ、正解は 3 !
# Server: 1 is 不明
# Server: 2 is 近付いた
# Server: 5 is 遠ざかった
# Server: 4 is 近付いた
# Server: 3 is 正解!
# 10 と 15 の間で数当てゲーム! -- 秘密だョ、正解は 12 !
# Server: 14 is 不明
# Server: 10 is 前回と同じ
# Server: 12 is 正解!
#
# Client: 1 is 不明
# Client: 2 is 近付いた
# Client: 5 is 遠ざかった
# Client: 4 is 近付いた
# Client: 3 is 正解!
# Client: 14 is 不明
# Client: 10 is 前回と同じ
# Client: 12 is 正解!
# == 接続が切断されました ==
こちらも期待通りに動作しています。
この async バージョンは複数のスレッドにまたがって実行されることもないので、debugger を利用してコードを追いかけることも容易です。
いかがだったでしょうか?
思っていたほど大変ではなかったのではないでしょうか?
今回利用していない機能は沢山ありますし、asyncio 関連は Python においてもまだまだ充実が図られている機能の一つです。
是非お互いがんばって勉強していきましょう!!
まとめ:
1: Python では既存のプログラムをコルーチンに書き換えることを容易にする従来構文の async バージョンが数多く用意されているほか、ヘルパー関数群も充実しています。
2: 従来ブロッキング I/O とスレッドを利用して記述していたコードを、asyncio 組み込みモジュールを利用して、コルーチンと非同期 I/O ( asynchronous I/O ) を利用したコードに最小限の労力でリファクタリングすることが可能です。
この記事に興味のある方は次の記事にも関心を持っているようです...
- People who read this article may also be interested in following articles ... -