"""gRPC server for the Book service, instrumented for protocol benchmarking.

RPC handlers wrapped with ``performance_test`` return their normal payload
together with timings (and process CPU samples) for re-deserializing the
incoming request and serializing the outgoing response on the server.
"""

import argparse
import functools
import logging
import time
from concurrent import futures

import grpc
import psutil
from google.protobuf.json_format import MessageToDict
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError

from database import Books, get_db_session, Paginator
from proto import Book_pb2_grpc, Book_pb2

logger = logging.getLogger(__name__)


def performance_test(response_object):
    """Wrap an RPC handler so its reply carries server-side protocol timings.

    The decorated handler must return the value for the ``response_data``
    field of ``response_object``. The wrapper then measures:

    1. deserializing the incoming request from its wire bytes, and
    2. serializing a reply built from the handler's payload,

    and returns a fresh ``response_object`` holding the payload plus the
    ``server_deserialize``, ``server_serialize`` and
    ``server_protocol_total_time`` fields.

    Args:
        response_object: Generated protobuf message class used for the reply.

    Returns:
        A decorator for ``(self, request, context)`` servicer methods.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, request, context):
            response_data = func(self, request, context)
            # Reply built only so that its serialization cost can be measured.
            test_return = response_object(response_data=response_data)

            process = psutil.Process()
            # psutil.cpu_count() may return None if it cannot be determined.
            cpu_count = psutil.cpu_count() or 1

            # 1. Deserialize the request. The wire bytes are produced *before*
            #    the timer starts so this metric excludes serialization cost.
            request_bytes = request.SerializeToString()
            request_deserializer = request.__class__.FromString
            # cpu_percent() without an interval reports usage since the
            # previous call, so this first call only opens the sample window.
            process.cpu_percent()
            parse_start = time.perf_counter()
            request_deserializer(request_bytes)
            parse_end = time.perf_counter()
            parse_cpu_percent = process.cpu_percent()

            # 2. Serialize the response.
            response_serializer = response_object.SerializeToString
            process.cpu_percent()
            serialize_start = time.perf_counter()
            response_serializer(test_return)
            serialize_end = time.perf_counter()
            serialize_cpu_percent = process.cpu_percent()

            parse_total_time = parse_end - parse_start
            serialize_total_time = serialize_end - serialize_start
            # Per-process CPU percent can exceed 100 on multi-core machines;
            # dividing by the core count expresses it as a whole-machine share.
            timing_data = {
                'server_deserialize': {
                    'time': parse_total_time,
                    'cpu': parse_cpu_percent / cpu_count,
                },
                'server_serialize': {
                    'time': serialize_total_time,
                    'cpu': serialize_cpu_percent / cpu_count,
                },
                'server_protocol_total_time': parse_total_time + serialize_total_time,
            }
            return response_object(response_data=response_data, **timing_data)

        return wrapper

    return decorator


def _book_to_columns(book):
    """Convert a Book protobuf message into keyword arguments for ``Books``.

    ``id`` is dropped: inserts leave it to the database sequence, and updates
    must not overwrite the primary key. ``MessageToDict`` follows the proto3
    JSON mapping, so fields at their default value are omitted.
    """
    columns = MessageToDict(book, preserving_proto_field_name=True)
    columns.pop("id", None)
    return columns


def _add_books(session, books):
    """Stage one ``Books`` row per protobuf message and commit them together."""
    for book in books:
        session.add(Books(**_book_to_columns(book)))
    session.commit()


class BookServiceServicer(Book_pb2_grpc.BookServiceServicer):
    """Book service backed by the ``Books`` table."""

    @performance_test(Book_pb2.GetListResponse)
    def GetList(self, request, context):
        """Return one page of books, padded by repetition to ``per_page`` items.

        ``list_data_limit`` caps how many distinct rows are read from the
        database (a non-positive value means "use ``per_page``"). The fetched
        rows are repeated cyclically so that the payload has ``per_page``
        items whenever at least one row was found; an empty page is returned
        as an empty list.
        """
        pages = request.pages
        per_page = request.per_page
        list_data_limit = request.list_data_limit
        if list_data_limit <= 0:
            list_data_limit = per_page

        with get_db_session() as session:
            paginator = Paginator(session)
            query = session.query(Books).order_by(Books.id)
            data = paginator.paginate(
                query,
                page=pages,
                per_page=min(per_page, list_data_limit),
                error_out=False,
            )
            books = [book.to_dict() for book in data.items]

        # Skip padding when nothing was fetched (empty table or a page past
        # the end); otherwise len(books) would be a zero divisor.
        if books and len(books) < per_page:
            repeats, remainder = divmod(per_page, len(books))
            books = books * repeats + books[:remainder]
        return books

    @performance_test(Book_pb2.GeneralResponse)
    def AddBooks(self, request, context):
        """Insert the given books, or do nothing when ``test_only`` is set.

        On an ``IntegrityError`` the ``books_id_seq`` sequence is realigned
        with the current maximum id and the insert is retried once. A second
        failure propagates to the caller.
        """
        books = request.books
        if request.test_only:
            return Book_pb2.MessageResponse(message="Books added successfully")

        with get_db_session() as session:
            try:
                _add_books(session, books)
            except IntegrityError as e:
                logger.warning("Unique constraint violation: %s, trying to reset sequence", e)
                session.rollback()
                # With is_called = false on an empty table, the next id is 1.
                # A plain setval(..., 0) would be rejected as below the
                # sequence minimum.
                session.execute(text(
                    "SELECT setval('books_id_seq', COALESCE(MAX(id), 1), "
                    "MAX(id) IS NOT NULL) FROM books"
                ))
                session.commit()
                _add_books(session, books)
        return Book_pb2.MessageResponse(message="Books added successfully")

    @performance_test(Book_pb2.GeneralResponse)
    def DeleteBooks(self, request, context):
        """Delete books by explicit id, or the ``delete_last_count`` newest rows.

        A positive ``delete_last_count`` takes precedence and removes that many
        rows with the highest ids; otherwise every id in ``book_ids`` is deleted.
        """
        book_ids = request.book_ids
        delete_last_count = request.delete_last_count
        with get_db_session() as session:
            if delete_last_count > 0:
                last_ids = (
                    session.query(Books.id)
                    .order_by(Books.id.desc())
                    .limit(delete_last_count)
                    .subquery()
                )
                # Select from the LIMITed subquery instead of using it directly
                # in IN (...) (presumably for databases that reject LIMIT there).
                session.query(Books).filter(
                    Books.id.in_(session.query(last_ids.c.id))
                ).delete(synchronize_session=False)
            else:
                session.query(Books).filter(
                    Books.id.in_(book_ids)
                ).delete(synchronize_session=False)
            session.commit()
        return Book_pb2.MessageResponse(message="Books deleted successfully")

    @performance_test(Book_pb2.GeneralResponse)
    def UpdateBook(self, request, context):
        """Overwrite the stored book whose id matches ``request.book.id``.

        Only fields present in the message's JSON form are copied, and ``id``
        is never changed. Replies with gRPC NOT_FOUND if no such book exists.
        """
        book_data = request.book
        with get_db_session() as session:
            book = session.query(Books).filter(Books.id == book_data.id).first()
            if book:
                # Protobuf messages have no .items(); convert them the same way
                # AddBooks does.
                for key, value in _book_to_columns(book_data).items():
                    setattr(book, key, value)
                session.commit()
                return Book_pb2.MessageResponse(message="Book updated successfully")
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details("Book not found")
            return Book_pb2.MessageResponse(message="Book not found")

    def Ping(self, request, context):
        """Liveness check that always answers "Pong!"."""
        return Book_pb2.MessageResponse(message="Pong!")


def serve(port=50051):
    """Start the Book gRPC server on ``port`` and block until it terminates.

    Message size limits are disabled (-1) so that large benchmark payloads are
    accepted in both directions.
    """
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
            ('grpc.max_receive_message_length', -1),
            ('grpc.max_send_message_length', -1),
        ],
    )
    Book_pb2_grpc.add_BookServiceServicer_to_server(BookServiceServicer(), server)
    listen_addr = f'[::]:{port}'
    server.add_insecure_port(listen_addr)
    server.start()
    print(f"gRPC server runs on: {listen_addr}")
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == '__main__':
    arg_parse = argparse.ArgumentParser(description="Run gRPC server")
    arg_parse.add_argument("--port", "-p", type=int, default=50051, help="Port to run the server on")
    args = arg_parse.parse_args()
    logging.basicConfig(level=logging.INFO)
    serve(args.port)