# Book database models, pagination helper, and random-data seeding utilities.
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, List, Optional

import sqlalchemy
from faker import Faker
from sqlalchemy import BigInteger, Column, Float, Integer, String, Text, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from sqlalchemy.pool import QueuePool
from tqdm import tqdm
engine = create_engine(
|
|
# 'sqlite:///rest_server_books.db',
|
|
'postgresql://postgres:liuyanfeng66@localhost:5432/BookRestGrpcCompare',
|
|
poolclass=QueuePool,
|
|
pool_size=30,
|
|
max_overflow=50,
|
|
pool_pre_ping=True,
|
|
echo=False,
|
|
# connect_args={
|
|
# "check_same_thread": False,
|
|
# "timeout": 30 # 文件锁最长等待秒数,视情况可调
|
|
# }
|
|
)
|
|
|
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
|
|
Base = declarative_base()
|
|
|
|
|
|
@dataclass
|
|
class PaginationResult:
|
|
"""分页结果数据类"""
|
|
items: List[Any]
|
|
total: int
|
|
page: int
|
|
per_page: int
|
|
pages: int
|
|
has_prev: bool
|
|
has_next: bool
|
|
prev_num: Optional[int]
|
|
next_num: Optional[int]
|
|
|
|
|
|
class Paginator:
|
|
def __init__(self, session: Session):
|
|
self.session = session
|
|
|
|
def paginate(self, query, page: int = 1, per_page: int = 20, error_out: bool = True) -> PaginationResult:
|
|
if per_page <= 0:
|
|
return PaginationResult(
|
|
items=[],
|
|
total=0,
|
|
page=1,
|
|
per_page=0,
|
|
pages=0,
|
|
has_prev=False,
|
|
has_next=False,
|
|
prev_num=None,
|
|
next_num=None
|
|
)
|
|
|
|
page = max(1, page)
|
|
per_page = max(1, per_page)
|
|
|
|
# Get total count of items
|
|
count_query = select(func.count()).select_from(query.statement.alias())
|
|
total = self.session.scalar(count_query)
|
|
|
|
# Get total pages
|
|
pages = (total + per_page - 1) // per_page
|
|
|
|
if error_out and page > pages and total > 0:
|
|
raise ValueError(f"Page {page} out of range (1-{pages})")
|
|
|
|
offset = (page - 1) * per_page
|
|
|
|
items = query.offset(offset).limit(per_page).all()
|
|
|
|
return PaginationResult(
|
|
items=items,
|
|
total=total,
|
|
page=page,
|
|
per_page=per_page,
|
|
pages=pages,
|
|
has_prev=page > 1,
|
|
has_next=page < pages,
|
|
prev_num=page - 1 if page > 1 else None,
|
|
next_num=page + 1 if page < pages else None
|
|
)
|
|
|
|
|
|
class BaseModel(Base):
|
|
__abstract__ = True
|
|
|
|
def to_dict(self, exclude: set = None):
|
|
exclude = exclude or set()
|
|
ret = {}
|
|
for c in self.__table__.columns:
|
|
if c.name not in exclude:
|
|
curr_item = getattr(self, c.name)
|
|
if isinstance(curr_item, Enum):
|
|
ret[c.name] = curr_item.value
|
|
else:
|
|
ret[c.name] = curr_item
|
|
|
|
return ret
|
|
|
|
|
|
class Books(BaseModel):
|
|
__tablename__ = 'books'
|
|
id = Column(BigInteger, primary_key=True, autoincrement=True)
|
|
isbn: str = Column(String(20), nullable=True, index=True)
|
|
barcode: str = Column(String(50), nullable=True)
|
|
|
|
title: str = Column(Text, nullable=False)
|
|
subtitle: str = Column(Text, nullable=True)
|
|
author: str = Column(Text, nullable=False)
|
|
translator: str = Column(Text, nullable=True)
|
|
editor: str = Column(Text, nullable=True)
|
|
publisher: str = Column(String(200), nullable=True)
|
|
publication_date = Column(BigInteger, nullable=True)
|
|
edition: str = Column(String(50), nullable=True)
|
|
pages: int = Column(Integer, nullable=True)
|
|
language: str = Column(String(50), nullable=True)
|
|
|
|
category_id: int = Column(Integer, nullable=True, index=True)
|
|
subject: str = Column(String(200), nullable=True)
|
|
keywords: str = Column(Text, nullable=True)
|
|
description: str = Column(Text, nullable=True)
|
|
abstract: str = Column(Text, nullable=True)
|
|
|
|
format: str = Column(String(50), nullable=True)
|
|
binding: str = Column(String(20), nullable=True)
|
|
weight: float = Column(Float, nullable=True)
|
|
cover_image: str = Column(String(500), nullable=True)
|
|
|
|
|
|
def generate_random_book_data(book_id: int, fake: Faker) -> dict:
|
|
return {
|
|
'id': book_id,
|
|
'isbn': fake.isbn13(separator='-'),
|
|
'barcode': fake.ean(length=13),
|
|
'title': fake.sentence(nb_words=6),
|
|
'subtitle': fake.sentence(nb_words=3),
|
|
'author': fake.name(),
|
|
'translator': fake.name(),
|
|
'editor': fake.name(),
|
|
'publisher': fake.company(),
|
|
'publication_date': random.randint(-2190472676, 1754206067),
|
|
'edition': fake.word(),
|
|
'pages': random.randint(100, 1000),
|
|
'language': random.choice(['English', 'Français', 'Deutsch']),
|
|
'category_id': random.randint(1, 100),
|
|
'subject': fake.word(),
|
|
'keywords': ", ".join(fake.words(nb=5, unique=True, ext_word_list=None)),
|
|
'description': fake.text(max_nb_chars=200),
|
|
'abstract': fake.text(max_nb_chars=100),
|
|
'format': random.choice(['16mo', '8vo', '4to', 'Folio', 'Quarto']),
|
|
'binding': random.choice(['expensive', 'cheap', 'paperback', 'hardcover']),
|
|
'weight': random.uniform(100, 2000),
|
|
'cover_image': "https://example.com/cover_image.jpg"
|
|
}
|
|
|
|
|
|
@contextmanager
|
|
def get_db_session():
|
|
session = SessionLocal()
|
|
try:
|
|
yield session
|
|
except Exception as e:
|
|
session.rollback()
|
|
print(f"Database error: {e}")
|
|
raise
|
|
finally:
|
|
session.close()
|
|
|
|
|
|
def check_and_generate_random_data(generate_count=10000):
|
|
Base.metadata.create_all(bind=engine)
|
|
|
|
fake = Faker('en_GB')
|
|
batch_size = 10000
|
|
|
|
with get_db_session() as session:
|
|
data_count = session.query(Books).count()
|
|
if data_count != 0:
|
|
if input(
|
|
f"Database already contains data ({data_count}). Do you want to regenerate random data? (y/n): ").lower() != 'y':
|
|
return
|
|
|
|
if input("Delete all existing data? (y/n): ").lower() == 'y':
|
|
if data_count > 0:
|
|
session.query(Books).delete()
|
|
session.commit()
|
|
|
|
for batch_start in tqdm(range(0, generate_count, batch_size), desc="Inserting batches"):
|
|
batch_end = min(batch_start + batch_size, generate_count)
|
|
|
|
batch_data = [
|
|
generate_random_book_data(i, fake)
|
|
for i in range(batch_start, batch_end)
|
|
]
|
|
|
|
with get_db_session() as session:
|
|
session.bulk_insert_mappings(Books, batch_data)
|
|
session.commit()
|
|
|
|
#
|
|
# with app.app_context():
|
|
# check_and_generate_random_data()
|
|
|
|
# check_and_generate_random_data()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
check_and_generate_random_data(int(input("Enter the number of random books to generate: ")))
|
|
print("Database initialized and random data generated.")
|