194 lines
7.1 KiB
Python
194 lines
7.1 KiB
Python
import grpc
|
|
from concurrent import futures
|
|
import logging
|
|
from sqlalchemy import text
|
|
from sqlalchemy.exc import IntegrityError
|
|
from proto import Book_pb2_grpc, Book_pb2
|
|
from google.protobuf.json_format import MessageToDict
|
|
from database import Books, get_db_session, Paginator
|
|
import psutil
|
|
import time
|
|
import argparse
|
|
|
|
def performance_test(response_object):
    """Build a decorator that attaches protobuf timing metrics to an RPC response.

    The wrapped handler's return value is placed in the ``response_data``
    field of a ``response_object`` message. Before returning, the wrapper
    benchmarks two protocol steps in isolation and records them:

    * ``server_deserialize`` -- parsing the incoming request from its wire bytes;
    * ``server_serialize`` -- serializing a response carrying only
      ``response_data`` (no timing fields);
    * ``server_protocol_total_time`` -- the sum of the two wall-clock times.

    Each step reports ``time`` (wall-clock seconds via ``time.perf_counter``)
    and ``cpu``. ``cpu`` is this process's ``psutil`` ``cpu_percent()`` since
    the priming call just before the step, divided by the logical CPU count.

    NOTE(review): the measured steps are usually far shorter than the OS
    CPU-time accounting granularity, so the ``cpu`` figures are likely coarse
    (often 0.0) -- confirm before drawing conclusions from them.

    Args:
        response_object: Protobuf message class used to build the response.
            It must accept ``response_data``, ``server_deserialize``,
            ``server_serialize`` and ``server_protocol_total_time`` keyword
            arguments.

    Returns:
        A decorator for servicer methods with the signature
        ``(self, request, context)``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, request, context):
            response_data = func(self, request, context)
            # Response without timing fields: this is what gets benchmarked.
            bare_response = response_object(response_data=response_data)

            process = psutil.Process()
            # psutil.cpu_count() returns None when it cannot be determined.
            cpu_count = psutil.cpu_count() or 1

            # 1. parse request. Produce the wire bytes *before* starting the
            # clock so that only deserialization is timed.
            request_bytes = request.SerializeToString()
            parse_request = type(request).FromString
            process.cpu_percent()  # prime: the next call reports usage since now
            parse_start = time.perf_counter()
            parse_request(request_bytes)
            parse_end = time.perf_counter()
            parse_cpu_percent = process.cpu_percent()

            # 2. serialize response. Prime the CPU counter before the clock
            # starts so the priming call is not part of the timed region.
            process.cpu_percent()
            serialize_start = time.perf_counter()
            bare_response.SerializeToString()
            serialize_end = time.perf_counter()
            serialize_cpu_percent = process.cpu_percent()

            parse_total_time = parse_end - parse_start
            serialize_total_time = serialize_end - serialize_start

            timing_data = {
                'server_deserialize': {
                    'time': parse_total_time,
                    'cpu': parse_cpu_percent / cpu_count,
                },
                'server_serialize': {
                    'time': serialize_total_time,
                    'cpu': serialize_cpu_percent / cpu_count,
                },
                'server_protocol_total_time': parse_total_time + serialize_total_time,
            }

            return response_object(
                response_data=response_data,
                **timing_data,
            )

        return wrapper

    return decorator
|
|
|
|
|
|
class BookServiceServicer(Book_pb2_grpc.BookServiceServicer):
    """gRPC servicer for BookService, backed by the ``Books`` table.

    Every RPC except ``Ping`` is wrapped by ``performance_test``. The value a
    handler returns (a list of book dicts or a ``MessageResponse``) therefore
    ends up in the ``response_data`` field of the decorated response message.
    """

    @staticmethod
    def _book_from_message(book):
        """Build a new ``Books`` row from a protobuf book, dropping its ``id``.

        The id is removed so the primary key is left for the database to
        assign (presumably via the ``books_id_seq`` sequence reset in AddBooks).
        """
        book_dict = MessageToDict(book, preserving_proto_field_name=True)
        book_dict.pop("id", None)
        return Books(**book_dict)

    @performance_test(Book_pb2.GetListResponse)
    def GetList(self, request, context):
        """Return one page of books, padded by repetition to ``per_page`` items.

        ``request.list_data_limit`` caps how many rows are read from the
        database; a value <= 0 means "use ``per_page``". When fewer than
        ``per_page`` rows are fetched, they are repeated cyclically until the
        list holds ``per_page`` entries. An empty page is returned as an
        empty list.
        """
        pages = request.pages
        per_page = request.per_page
        list_data_limit = request.list_data_limit
        if list_data_limit <= 0:
            list_data_limit = per_page

        with get_db_session() as session:
            paginator = Paginator(session)
            query = session.query(Books).order_by(Books.id)
            data = paginator.paginate(query, page=pages, per_page=min(per_page, list_data_limit), error_out=False)
            books = [book.to_dict() for book in data.items]

        # Pad by cycling the fetched rows. An empty page has nothing to cycle
        # (and would previously divide by zero), so it is returned unchanged.
        if books and len(books) < per_page:
            repeats, remainder = divmod(per_page, len(books))
            books = books * repeats + books[:remainder]

        return books

    @performance_test(Book_pb2.GeneralResponse)
    def AddBooks(self, request, context):
        """Insert every book in ``request.books`` in one transaction.

        When ``request.test_only`` is set, nothing is written. If the insert
        raises ``IntegrityError``, the transaction is rolled back and the
        ``books_id_seq`` sequence is moved past the current ``MAX(id)``
        (presumably to recover from a sequence lagging behind existing ids).
        The insert is then retried once. A second failure propagates to gRPC.
        """
        if request.test_only:
            return Book_pb2.MessageResponse(message="Books added successfully")

        books = request.books
        with get_db_session() as session:
            try:
                session.add_all([self._book_from_message(book) for book in books])
                session.commit()
            except IntegrityError as e:
                logging.getLogger(__name__).warning(
                    "Unique constraint violation: %s, trying to reset sequence", e
                )
                session.rollback()
                # setval(seq, MAX(id)+1, false) makes the next nextval() return
                # MAX(id)+1. Unlike setval(seq, COALESCE(MAX(id), 0)) it also
                # works on an empty table, where 0 would be below the default
                # sequence minimum of 1.
                session.execute(text(
                    "SELECT setval('books_id_seq', "
                    "(SELECT COALESCE(MAX(id), 0) + 1 FROM books), false)"
                ))
                session.commit()

                session.add_all([self._book_from_message(book) for book in books])
                session.commit()

        return Book_pb2.MessageResponse(message="Books added successfully")

    @performance_test(Book_pb2.GeneralResponse)
    def DeleteBooks(self, request, context):
        """Delete books either by explicit ids or by taking the highest ids.

        When ``request.delete_last_count`` is positive, the rows with the
        ``delete_last_count`` largest ids are deleted and ``request.book_ids``
        is ignored. Otherwise every row whose id is in ``request.book_ids``
        is deleted.
        """
        book_ids = request.book_ids
        delete_last_count = request.delete_last_count

        with get_db_session() as session:
            if delete_last_count > 0:
                last_ids = (
                    session.query(Books.id)
                    .order_by(Books.id.desc())
                    .limit(delete_last_count)
                    .subquery()
                )
                condition = Books.id.in_(session.query(last_ids.c.id))
            else:
                # Plain list: the protobuf repeated container is not a list.
                condition = Books.id.in_(list(book_ids))

            # Bulk delete; synchronize_session=False skips reconciling objects
            # already loaded in this session.
            session.query(Books).filter(condition).delete(synchronize_session=False)
            session.commit()

        return Book_pb2.MessageResponse(message="Books deleted successfully")

    @performance_test(Book_pb2.GeneralResponse)
    def UpdateBook(self, request, context):
        """Copy the fields of ``request.book`` onto the stored book with the same id.

        Only fields present in the ``MessageToDict`` output are copied. By
        default ``MessageToDict`` omits fields at their proto3 default value.
        The id itself is never rewritten. If no such book exists, the gRPC
        status is set to ``NOT_FOUND``.
        """
        book_data = request.book
        with get_db_session() as session:
            book = session.query(Books).filter(Books.id == book_data.id).first()
            if book is None:
                context.set_code(grpc.StatusCode.NOT_FOUND)
                context.set_details("Book not found")
                return Book_pb2.MessageResponse(message="Book not found")

            # A protobuf message has no .items(); convert it to a dict first,
            # the same way the inserted books are converted in AddBooks.
            updates = MessageToDict(book_data, preserving_proto_field_name=True)
            updates.pop("id", None)
            for key, value in updates.items():
                setattr(book, key, value)
            session.commit()
            return Book_pb2.MessageResponse(message="Book updated successfully")

    def Ping(self, request, context):
        """Liveness check; always answers "Pong!" (not benchmarked)."""
        return Book_pb2.MessageResponse(message="Pong!")
|
|
|
|
|
|
def serve(port=50051, max_workers=10):
    """Start the BookService gRPC server and block until it terminates.

    Args:
        port: TCP port to listen on, on all interfaces (``[::]``). ``0`` asks
            the OS for a free port; the port actually bound is printed.
        max_workers: Size of the thread pool that executes RPC handlers.

    Raises:
        RuntimeError: If the listening port could not be bound.

    The send and receive message-size limits are both set to ``-1``
    (unlimited). Ctrl+C stops the server without a grace period.
    """
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers),
        options=[
            ('grpc.max_receive_message_length', -1),
            ('grpc.max_send_message_length', -1),
        ],
    )

    Book_pb2_grpc.add_BookServiceServicer_to_server(BookServiceServicer(), server)

    listen_addr = f'[::]:{port}'
    bound_port = server.add_insecure_port(listen_addr)
    # Some grpcio versions report a failed bind by returning 0 instead of
    # raising; fail loudly rather than "serving" on no port at all.
    if bound_port == 0:
        raise RuntimeError(f"Failed to bind gRPC server to {listen_addr}")

    server.start()
    print(f"gRPC server runs on: [::]:{bound_port}")

    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        server.stop(0)
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the listening port and run the server.
    arg_parse = argparse.ArgumentParser(description="Run gRPC server")
    arg_parse.add_argument("--port", "-p", type=int, default=50051, help="Port to run the server on")

    args = arg_parse.parse_args()

    logging.basicConfig(level=logging.INFO)
    serve(args.port)
|