2025-08-20 11:54:40 +01:00

222 lines
6.9 KiB
Python

from sqlalchemy import create_engine, select, func, Column, Integer, String, Text, Float, BigInteger
from sqlalchemy.orm import sessionmaker, Session, declarative_base
from sqlalchemy.pool import QueuePool
import sqlalchemy
from datetime import datetime
from tqdm import tqdm
from enum import Enum
from faker import Faker
import random
import os
from contextlib import contextmanager
from typing import List, Optional, Any
from dataclasses import dataclass
# Database connectivity.
# SECURITY NOTE(review): the connection string embeds real-looking credentials;
# it is now read from the DATABASE_URL environment variable when set, with the
# original URL kept as a fallback so existing deployments keep working.
DATABASE_URL = os.environ.get(
    'DATABASE_URL',
    # 'sqlite:///rest_server_books.db',
    'postgresql://postgres:liuyanfeng66@localhost:5432/BookRestGrpcCompare',
)
engine = create_engine(
    DATABASE_URL,
    poolclass=QueuePool,
    pool_size=30,        # persistent connections kept open in the pool
    max_overflow=50,     # additional transient connections allowed under load
    pool_pre_ping=True,  # validate each connection before use; drops stale ones
    echo=False,
    # SQLite-only options, kept for when the sqlite URL above is re-enabled:
    # connect_args={
    #     "check_same_thread": False,
    #     "timeout": 30  # max seconds to wait on the file lock; tune as needed
    # }
)
# Session factory: explicit commit/flush control, bound to the engine above.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
@dataclass
class PaginationResult:
    """Container for one page of query results plus pagination metadata."""
    items: List[Any]         # the rows belonging to this page
    total: int               # total matching rows across all pages
    page: int                # current page number (1-based)
    per_page: int            # requested page size
    pages: int               # total number of pages for `total` rows
    has_prev: bool           # True when page > 1
    has_next: bool           # True when page < pages
    prev_num: Optional[int]  # previous page number, or None on the first page
    next_num: Optional[int]  # next page number, or None on the last page
class Paginator:
    """Offset/limit pagination helper for legacy-style SQLAlchemy queries."""

    def __init__(self, session: Session):
        self.session = session

    def paginate(self, query, page: int = 1, per_page: int = 20,
                 error_out: bool = True) -> PaginationResult:
        """Return one page of *query*'s results as a PaginationResult.

        Args:
            query: legacy ``Query`` object (must expose ``.statement``,
                ``.offset``/``.limit`` and ``.all()``).
            page: 1-based page number; values below 1 are clamped to 1.
            per_page: page size; values <= 0 short-circuit to an empty result.
            error_out: when True, raise ``ValueError`` if *page* lies beyond
                the last page (only if at least one row matched).

        Raises:
            ValueError: see *error_out*.
        """
        # Degenerate page size: return a well-formed empty result instead of
        # issuing a query with LIMIT 0.
        if per_page <= 0:
            return PaginationResult(
                items=[],
                total=0,
                page=1,
                per_page=0,
                pages=0,
                has_prev=False,
                has_next=False,
                prev_num=None,
                next_num=None
            )
        page = max(1, page)
        per_page = max(1, per_page)
        # Count matched rows by wrapping the original statement in a subquery
        # so its joins/filters are preserved exactly.
        count_query = select(func.count()).select_from(query.statement.alias())
        # scalar() is typed Optional — coerce a None result to 0 so the
        # arithmetic below cannot raise TypeError.
        total = self.session.scalar(count_query) or 0
        # Ceiling division: pages needed to hold `total` rows.
        pages = (total + per_page - 1) // per_page
        if error_out and page > pages and total > 0:
            raise ValueError(f"Page {page} out of range (1-{pages})")
        offset = (page - 1) * per_page
        items = query.offset(offset).limit(per_page).all()
        return PaginationResult(
            items=items,
            total=total,
            page=page,
            per_page=per_page,
            pages=pages,
            has_prev=page > 1,
            has_next=page < pages,
            prev_num=page - 1 if page > 1 else None,
            next_num=page + 1 if page < pages else None
        )
class BaseModel(Base):
    """Abstract declarative base that adds a plain-dict serializer."""
    __abstract__ = True

    def to_dict(self, exclude: set = None):
        """Serialize the mapped columns to a dict, unwrapping Enum values.

        Columns whose names appear in *exclude* are omitted.
        """
        skipped = exclude or set()

        def unwrap(value):
            # Store the Enum's payload, not the Enum member itself.
            return value.value if isinstance(value, Enum) else value

        return {
            column.name: unwrap(getattr(self, column.name))
            for column in self.__table__.columns
            if column.name not in skipped
        }
class Books(BaseModel):
    """ORM model for the `books` table: one bibliographic record per row."""
    __tablename__ = 'books'
    id = Column(BigInteger, primary_key=True, autoincrement=True)
    # Indexed to support lookup by ISBN.
    isbn: str = Column(String(20), nullable=True, index=True)
    barcode: str = Column(String(50), nullable=True)
    # Title and author are the only required fields.
    title: str = Column(Text, nullable=False)
    subtitle: str = Column(Text, nullable=True)
    author: str = Column(Text, nullable=False)
    translator: str = Column(Text, nullable=True)
    editor: str = Column(Text, nullable=True)
    publisher: str = Column(String(200), nullable=True)
    # Stored as an integer — presumably a Unix epoch timestamp in seconds
    # (the seeder generates values in roughly the 1900–2025 range); TODO confirm.
    publication_date = Column(BigInteger, nullable=True)
    edition: str = Column(String(50), nullable=True)
    pages: int = Column(Integer, nullable=True)
    language: str = Column(String(50), nullable=True)
    # Indexed foreign-key-style reference; no FK constraint is declared here.
    category_id: int = Column(Integer, nullable=True, index=True)
    subject: str = Column(String(200), nullable=True)
    # Comma-separated keyword list (see the data generator).
    keywords: str = Column(Text, nullable=True)
    description: str = Column(Text, nullable=True)
    abstract: str = Column(Text, nullable=True)
    format: str = Column(String(50), nullable=True)
    binding: str = Column(String(20), nullable=True)
    weight: float = Column(Float, nullable=True)
    # URL of the cover image, not a binary blob.
    cover_image: str = Column(String(500), nullable=True)
def generate_random_book_data(book_id: int, fake: Faker) -> dict:
    """Build one fake `books` row as a mapping suitable for bulk insertion.

    Args:
        book_id: used verbatim as the primary-key value.
        fake: Faker instance that supplies the synthetic field values.
    """
    # Keyword-argument form of the original dict literal; field order (and
    # therefore the order of RNG calls) is unchanged.
    return dict(
        id=book_id,
        isbn=fake.isbn13(separator='-'),
        barcode=fake.ean(length=13),
        title=fake.sentence(nb_words=6),
        subtitle=fake.sentence(nb_words=3),
        author=fake.name(),
        translator=fake.name(),
        editor=fake.name(),
        publisher=fake.company(),
        # Presumably epoch seconds spanning roughly 1900..2025 — TODO confirm.
        publication_date=random.randint(-2190472676, 1754206067),
        edition=fake.word(),
        pages=random.randint(100, 1000),
        language=random.choice(['English', 'Français', 'Deutsch']),
        category_id=random.randint(1, 100),
        subject=fake.word(),
        keywords=", ".join(fake.words(nb=5, unique=True, ext_word_list=None)),
        description=fake.text(max_nb_chars=200),
        abstract=fake.text(max_nb_chars=100),
        format=random.choice(['16mo', '8vo', '4to', 'Folio', 'Quarto']),
        binding=random.choice(['expensive', 'cheap', 'paperback', 'hardcover']),
        weight=random.uniform(100, 2000),
        cover_image="https://example.com/cover_image.jpg"
    )
@contextmanager
def get_db_session():
    """Yield a scoped SQLAlchemy session; roll back on error, always close."""
    db = SessionLocal()
    try:
        yield db
    except Exception as exc:
        # Undo any partial work, report, and let the caller see the failure.
        db.rollback()
        print(f"Database error: {exc}")
        raise
    finally:
        # Return the connection to the pool regardless of outcome.
        db.close()
def check_and_generate_random_data(generate_count=10000, batch_size=10000):
    """Ensure the schema exists and (re)populate `books` with fake rows.

    Interactive: when the table already holds data, the user is asked whether
    to regenerate and whether to wipe the existing rows first.

    Args:
        generate_count: number of fake books to insert (ids 0..count-1).
        batch_size: rows per bulk-insert batch; the default matches the
            previously hard-coded value, so existing callers are unaffected.
    """
    Base.metadata.create_all(bind=engine)
    fake = Faker('en_GB')
    with get_db_session() as session:
        data_count = session.query(Books).count()
        if data_count != 0:
            if input(
                    f"Database already contains data ({data_count}). Do you want to regenerate random data? (y/n): ").lower() != 'y':
                return
            # NOTE(review): declining the wipe keeps old rows whose ids may
            # collide with the newly generated ids 0..generate_count-1.
            if input("Delete all existing data? (y/n): ").lower() == 'y':
                # data_count is already known non-zero here, so no extra guard
                # is needed before deleting.
                session.query(Books).delete()
                session.commit()
    for batch_start in tqdm(range(0, generate_count, batch_size), desc="Inserting batches"):
        batch_end = min(batch_start + batch_size, generate_count)
        batch_data = [
            generate_random_book_data(i, fake)
            for i in range(batch_start, batch_end)
        ]
        # Fresh session per batch keeps transactions small and releases the
        # pooled connection between batches.
        with get_db_session() as session:
            session.bulk_insert_mappings(Books, batch_data)
            session.commit()
#
# with app.app_context():
# check_and_generate_random_data()
# check_and_generate_random_data()
if __name__ == "__main__":
    # Script entry point: prompt for a row count, then seed the database.
    requested = int(input("Enter the number of random books to generate: "))
    check_and_generate_random_data(requested)
    print("Database initialized and random data generated.")