Load or Embed
Check for a cached embedding file and skip the API if one exists
Avoiding redundant work
Embedding a full PDF can take minutes and costs API quota. If you already embedded this PDF in a previous run, you should not pay that cost again. The main() function checks for a cache file before calling the API.
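To put "can take minutes" in numbers, here is a back-of-the-envelope estimate. It reuses the chunking and batching arithmetic from chunk_text and embed_all_chunks in the starter code: 500-character chunks with 100 characters of overlap, one request per chunk, 90 requests per batch, and a 60-second pause between batches. The 200,000-character document length is an assumption chosen for illustration. The helper name estimate_embedding_cost is hypothetical, and the estimate counts only the pauses, not the time each request takes.

```python
def estimate_embedding_cost(text_length, chunk_size=500, overlap=100, batch_size=90, pause_seconds=60):
    # Same stride as chunk_text: a new chunk starts every (chunk_size - overlap) characters
    num_chunks = len(range(0, text_length, chunk_size - overlap))
    # Same batching as embed_all_chunks: one pause after every batch except the last
    batch_starts = range(0, num_chunks, batch_size)
    num_pauses = sum(1 for i in batch_starts if i + batch_size < num_chunks)
    # embed_all_chunks makes one embedding call per chunk
    return num_chunks, num_pauses * pause_seconds


# Hypothetical document with 200,000 characters of extracted text
calls, sleep_seconds = estimate_embedding_cost(200_000)
print(f"{calls} embedding calls, at least {sleep_seconds} seconds of rate-limit pauses")
# prints: 500 embedding calls, at least 300 seconds of rate-limit pauses
```

When a cache file exists, a second run makes zero embedding calls and skips every pause.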
The decision is a single if/else:
cache exists?
  yes → load chunks and embeddings from disk
  no  → extract text, chunk, embed, save to cache

The branching logic
The load_embeddings function you wrote in the previous chapter returns None when the cache file does not exist. That makes the check simple:
```python
cached = load_embeddings(cache_path)
if cached:
    chunks, embeddings = cached
else:
    # embed from scratch and save
    ...
```

When the cache file exists, cached holds a tuple of (chunks, embeddings). A non-empty tuple is truthy, so the if branch runs and unpacks that tuple into two variables. When the file is missing, cached is None, which is falsy. The else branch then runs the full embedding pipeline and saves the result so the next run can skip this step.
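The sketch below copies load_embeddings and save_embeddings from the starter code and runs them in a temporary directory to show both return shapes. It also shows why `if cached:` is safe here. Any two-element tuple is truthy, even when both lists inside it are empty, so the check behaves the same as `if cached is not None:`. The temporary directory, file name, and sample data are illustrative only.

```python
import json
import os
import tempfile


def save_embeddings(chunks, embeddings, cache_path):
    # Copied from the starter code
    data = {"chunks": chunks, "embeddings": embeddings}
    with open(cache_path, "w") as f:
        json.dump(data, f)


def load_embeddings(cache_path):
    # Copied from the starter code: None when the file is missing, else a tuple
    if not os.path.exists(cache_path):
        return None
    with open(cache_path) as f:
        data = json.load(f)
    return data["chunks"], data["embeddings"]


with tempfile.TemporaryDirectory() as tmp:
    cache_path = os.path.join(tmp, "doc.pdf.cache.json")

    # No file yet: load_embeddings returns None, which is falsy
    missing = load_embeddings(cache_path)
    print(missing, bool(missing))  # None False

    # After saving, it returns a (chunks, embeddings) tuple that unpacks cleanly
    save_embeddings(["chunk one", "chunk two"], [[0.1, 0.2], [0.3, 0.4]], cache_path)
    cached = load_embeddings(cache_path)
    chunks, embeddings = cached
    print(len(chunks), len(embeddings))  # 2 2

    # Even a cache holding two empty lists is a non-empty tuple, so `if cached:` takes the load branch
    save_embeddings([], [], cache_path)
    print(bool(load_embeddings(cache_path)))  # True
```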
Instructions
Continue the main function. The starter code has the argument parsing from the previous chapter already filled in.
- Print `f"Loading {pdf_path}..."`.
- Create a variable named `cached`. Assign it `load_embeddings(cache_path)`.
- Add an `if cached:` block. Inside it, unpack `cached` into `chunks, embeddings` using `chunks, embeddings = cached`, then print `f"Loaded cache from {cache_path}"`.
- Add an `else:` block. Inside it:
  - create `text` from `extract_text(pdf_path)`
  - create `chunks` from `chunk_text(text)`
  - print `f"No cache found. Embedding {len(chunks)} chunks..."`
  - create `embeddings` from `embed_all_chunks(client, chunks)`
  - call `save_embeddings(chunks, embeddings, cache_path)`
  - print `f"Cache saved to {cache_path}"`
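The steps above form a single load-or-embed flow. The sketch below runs that flow twice without touching the Gemini API. Several names are hypothetical stand-ins: `fake_extract_text` replaces extract_text, `fake_embed_all_chunks` replaces embed_all_chunks, the `stats` counter stands in for API usage, and `load_or_embed` wraps the Step 1 to Step 4 logic so it can be called repeatedly. chunk_text and the cache helpers are copied from the starter code. The first run embeds and writes the cache. The second run finds the cache and makes no embedding calls.

```python
import json
import os
import tempfile


def chunk_text(text, chunk_size=500, overlap=100):
    # Copied from the starter code
    chunks = []
    for i in range(0, len(text), chunk_size - overlap):
        chunks.append(text[i : i + chunk_size])
    return chunks


def save_embeddings(chunks, embeddings, cache_path):
    # Copied from the starter code
    data = {"chunks": chunks, "embeddings": embeddings}
    with open(cache_path, "w") as f:
        json.dump(data, f)


def load_embeddings(cache_path):
    # Copied from the starter code
    if not os.path.exists(cache_path):
        return None
    with open(cache_path) as f:
        data = json.load(f)
    return data["chunks"], data["embeddings"]


# Hypothetical counter standing in for API quota usage
stats = {"embed_calls": 0}


def fake_extract_text(pdf_path):
    # Hypothetical stand-in for extract_text: 1,000 characters of dummy text, no PDF is read
    return "x" * 1000


def fake_embed_all_chunks(chunks):
    # Hypothetical stand-in for embed_all_chunks: counts one "call" per chunk, no network
    stats["embed_calls"] += len(chunks)
    return [[float(len(chunk))] for chunk in chunks]


def load_or_embed(pdf_path):
    # Hypothetical wrapper around the Step 1-4 logic from main()
    cache_path = pdf_path + ".cache.json"
    print(f"Loading {pdf_path}...")
    cached = load_embeddings(cache_path)
    if cached:
        chunks, embeddings = cached
        print(f"Loaded cache from {cache_path}")
    else:
        text = fake_extract_text(pdf_path)
        chunks = chunk_text(text)
        print(f"No cache found. Embedding {len(chunks)} chunks...")
        embeddings = fake_embed_all_chunks(chunks)
        save_embeddings(chunks, embeddings, cache_path)
        print(f"Cache saved to {cache_path}")
    return chunks, embeddings


with tempfile.TemporaryDirectory() as tmp:
    pdf_path = os.path.join(tmp, "report.pdf")
    first = load_or_embed(pdf_path)   # no cache yet: embeds 3 chunks and writes the cache
    calls_after_first = stats["embed_calls"]
    second = load_or_embed(pdf_path)  # cache found: no further embedding calls
    calls_after_second = stats["embed_calls"]
```

Both runs return the same chunks and vectors, but only the first one pays for embedding. In the real main(), the text comes from the actual PDF and client is passed to embed_all_chunks.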
```python
import json
import os
import sys
import time

import numpy as np
import pypdf
from dotenv import load_dotenv
from google import genai
from google.genai import types


def extract_text(pdf_path):
    # Concatenate the extracted text of every page in the PDF
    reader = pypdf.PdfReader(pdf_path)
    pages = [page.extract_text() for page in reader.pages]
    return "\n".join(pages)


def chunk_text(text, chunk_size=500, overlap=100):
    # Fixed-size character chunks; each chunk starts (chunk_size - overlap) characters after the previous one
    chunks = []
    for i in range(0, len(text), chunk_size - overlap):
        chunks.append(text[i : i + chunk_size])
    return chunks


def preview_chunks(chunks):
    print(f"Total chunks: {len(chunks)}")
    print(f"First chunk:\n{chunks[0]}")


def create_client():
    # Read GEMINI_API_KEY from .env and build a Gemini client
    load_dotenv()
    api_key = os.getenv("GEMINI_API_KEY")
    client = genai.Client(api_key=api_key)
    return client


def embed_text(client, text):
    result = client.models.embed_content(
        model="gemini-embedding-001",
        contents=text,
        config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT"),
    )
    return result.embeddings[0].values


def embed_all_chunks(client, chunks):
    # One API call per chunk, pausing 60 seconds between batches to stay under the rate limit
    BATCH_SIZE = 90
    embeddings = []
    for i in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[i : i + BATCH_SIZE]
        for chunk in batch:
            embeddings.append(embed_text(client, chunk))
        if i + BATCH_SIZE < len(chunks):
            print("Rate limit pause — waiting 60 seconds...")
            time.sleep(60)
    return embeddings


def cosine_similarity(vec_a, vec_b):
    dot = np.dot(vec_a, vec_b)
    norm = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    return dot / norm


def search(client, query, chunks, embeddings, top_k=3):
    # Embed the query, score every chunk by cosine similarity, and return the top_k chunks
    result = client.models.embed_content(
        model="gemini-embedding-001",
        contents=query,
        config=types.EmbedContentConfig(task_type="RETRIEVAL_QUERY"),
    )
    query_vector = result.embeddings[0].values
    scores = [(cosine_similarity(query_vector, emb), chunk) for emb, chunk in zip(embeddings, chunks)]
    scores.sort(key=lambda x: x[0], reverse=True)
    return [chunk for _, chunk in scores[:top_k]]


def test_search(client, pdf_path, question):
    text = extract_text(pdf_path)
    chunks = chunk_text(text)
    embeddings = embed_all_chunks(client, chunks)
    results = search(client, question, chunks, embeddings)
    for i, chunk in enumerate(results, 1):
        print(f"Result {i}:\n{chunk}\n")


def build_prompt(question, context_chunks):
    context = "\n\n".join(context_chunks)
    prompt = f"You are a helpful assistant. Answer the question using only the context below.\nIf the answer is not in the context, say \"I don't know.\"\n\nContext:\n{context}\n\nQuestion:\n{question}"
    return prompt


def generate_answer(client, prompt):
    response = client.models.generate_content(model="gemini-2.5-flash", contents=prompt)
    return response.text


def print_result(answer, source_chunks, show_sources=True):
    print("Answer:")
    print(answer)
    if show_sources:
        print("\nSources:")
        for i, chunk in enumerate(source_chunks, 1):
            print(f"Source {i}:\n{chunk}\n")


def save_embeddings(chunks, embeddings, cache_path):
    # Write the chunks and their vectors to a JSON cache file
    data = {"chunks": chunks, "embeddings": embeddings}
    with open(cache_path, "w") as f:
        json.dump(data, f)


def load_embeddings(cache_path):
    # Return (chunks, embeddings) from the cache, or None if the file does not exist
    if not os.path.exists(cache_path):
        return None
    with open(cache_path) as f:
        data = json.load(f)
    return data["chunks"], data["embeddings"]


def main():
    pdf_path = sys.argv[1]
    question = sys.argv[2]
    cache_path = pdf_path + ".cache.json"
    client = create_client()
    # Step 1: Print loading message
    # Step 2: Try loading from cache
    # Step 3: If cached, unpack into chunks and embeddings
    # Step 4: Else, extract text, chunk, embed, and save


if __name__ == "__main__":
    main()
```
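This is what save_embeddings leaves on disk: a single JSON object with a "chunks" list and an "embeddings" list of float lists. The file sits next to the PDF under the name main() builds with `pdf_path + ".cache.json"`. The sketch writes a tiny cache into a temporary directory and reads the raw JSON back. The file name and sample vectors are illustrative. Python's json module writes floats with enough digits to read them back exactly, so the reloaded vectors compare equal to the originals. json.dump only accepts plain Python types, so the embeddings should stay as lists of floats, which is what the embedding call's `.values` provides. A NumPy array would need `.tolist()` before saving.

```python
import json
import os
import tempfile


def save_embeddings(chunks, embeddings, cache_path):
    # Copied from the starter code
    data = {"chunks": chunks, "embeddings": embeddings}
    with open(cache_path, "w") as f:
        json.dump(data, f)


# Illustrative data; real embeddings have many more dimensions
chunks = ["First chunk of text.", "Second chunk of text."]
embeddings = [[0.1, -0.25, 0.333], [1e-05, 0.5, -0.75]]

with tempfile.TemporaryDirectory() as tmp:
    pdf_path = os.path.join(tmp, "paper.pdf")
    cache_path = pdf_path + ".cache.json"  # same naming rule as main()
    save_embeddings(chunks, embeddings, cache_path)
    cache_name = os.path.basename(cache_path)
    with open(cache_path) as f:
        raw = f.read()  # the raw JSON text written by json.dump
    data = json.loads(raw)

print(cache_name)  # paper.pdf.cache.json
print(sorted(data))  # ['chunks', 'embeddings']
print(data["embeddings"] == embeddings)  # True: floats survive the JSON round trip
```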