Cache Embeddings to Disk
Save embeddings to a JSON file so the app skips re-embedding on subsequent runs
Writing code and entering commands is only available on desktop. Open this page on a larger screen to complete this chapter.
The problem with re-embedding every run
Embedding 1,731 chunks takes several minutes and costs API quota. If you ask a second question about the same PDF, you should not pay that cost again.
The solution: save the chunks and their vectors to a JSON file after the first run. On the next run, load from that file instead of calling the API.
Run 1: embed → save to cache.json → answer question
Run 2: load from cache.json → answer question (instant)
The cache format
A single JSON object with two keys:
{
"chunks": ["chunk one...", "chunk two...", ...],
"embeddings": [[0.12, -0.04, ...], [0.33, 0.91, ...], ...]
}
The json module is part of Python's standard library — no installation required.
Instructions
Write two functions. The starter code provides both signatures.
- In `save_embeddings`, create a variable named `data`. Assign it a dict with two keys: `"chunks"` set to `chunks`, and `"embeddings"` set to `embeddings`.
- Open `cache_path` for writing using `with open(cache_path, "w") as f:`.
- Inside the `with` block, call `json.dump(data, f)`.
- In `load_embeddings`, add an `if` statement: `if not os.path.exists(cache_path):`. Inside it, return `None`.
- Open `cache_path` for reading using `with open(cache_path) as f:`.
- Inside the `with` block, create a variable named `data`. Assign it `json.load(f)`.
- Return `data["chunks"], data["embeddings"]`.
import json
import os
import time
import numpy as np
import pypdf
from dotenv import load_dotenv
from google import genai
from google.genai import types
def extract_text(pdf_path):
    """Return the text of every page of a PDF as one newline-joined string.

    Args:
        pdf_path: Location of the PDF, passed straight to ``pypdf.PdfReader``.

    Returns:
        The per-page results of ``page.extract_text()`` joined with ``"\n"``,
        in page order.
    """
    pdf = pypdf.PdfReader(pdf_path)
    return "\n".join(page.extract_text() for page in pdf.pages)
def chunk_text(text, chunk_size=500, overlap=100):
    """Split ``text`` into fixed-size character windows that overlap.

    Consecutive windows start ``chunk_size - overlap`` characters apart, so
    each window repeats the last ``overlap`` characters of the previous one.
    The final window may be shorter than ``chunk_size``.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared between neighbouring chunks;
            must be smaller than ``chunk_size``.

    Returns:
        A list of chunk strings in document order (empty for empty ``text``).

    Raises:
        ValueError: If ``chunk_size`` is not positive, or if ``overlap`` is
            not smaller than ``chunk_size``.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    step = chunk_size - overlap
    # A zero step makes range() fail with a cryptic message and a negative
    # step silently yields no chunks at all, so reject both explicitly.
    if step <= 0:
        raise ValueError("overlap must be smaller than chunk_size")

    return [text[start : start + chunk_size] for start in range(0, len(text), step)]
def preview_chunks(chunks):
    """Print how many chunks there are and, if any exist, the first one.

    Args:
        chunks: A sequence of chunk strings.
    """
    total = len(chunks)
    print(f"Total chunks: {total}")
    if total == 0:
        # Nothing to preview; indexing chunks[0] would raise IndexError.
        return
    print(f"First chunk:\n{chunks[0]}")
def create_client():
    """Build a ``genai.Client`` using the ``GEMINI_API_KEY`` environment variable.

    ``load_dotenv()`` is called first (presumably so a local ``.env`` file can
    supply the key — confirm against the python-dotenv docs). The key is read
    from the environment afterwards and may be ``None`` if it is not set.

    Returns:
        The constructed ``genai.Client``.
    """
    load_dotenv()
    return genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
def embed_text(client, text):
    """Embed one piece of text with ``gemini-embedding-001`` as a document.

    Args:
        client: Object exposing ``models.embed_content`` (as returned by
            ``create_client``).
        text: The text to embed.

    Returns:
        The ``values`` of the first embedding in the response.
    """
    document_config = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
    response = client.models.embed_content(
        model="gemini-embedding-001",
        contents=text,
        config=document_config,
    )
    first_embedding = response.embeddings[0]
    return first_embedding.values
def embed_all_chunks(client, chunks, *, batch_size=90, pause_seconds=60):
    """Embed every chunk, pausing between batches to respect rate limits.

    Chunks are embedded one at a time via ``embed_text``. After each full
    batch, if more chunks remain, a message is printed and the function
    sleeps before continuing. No pause happens after the final batch.

    Args:
        client: Client object forwarded to ``embed_text``.
        chunks: Sequence of chunk strings to embed.
        batch_size: Number of chunks to embed between pauses; must be
            positive. Defaults to 90.
        pause_seconds: Seconds to sleep between batches. Defaults to 60.

    Returns:
        A list with one embedding per chunk, in the same order as ``chunks``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    embeddings = []
    total = len(chunks)
    for start in range(0, total, batch_size):
        batch = chunks[start : start + batch_size]
        embeddings.extend(embed_text(client, chunk) for chunk in batch)
        # Only pause when another batch is still to come.
        if start + batch_size < total:
            print(f"Rate limit pause — waiting {pause_seconds} seconds...")
            time.sleep(pause_seconds)
    return embeddings
def cosine_similarity(vec_a, vec_b):
    """Return the cosine similarity between two vectors.

    Args:
        vec_a: First vector (any sequence accepted by ``np.dot``).
        vec_b: Second vector, same length as ``vec_a``.

    Returns:
        ``dot(vec_a, vec_b) / (|vec_a| * |vec_b|)``, or ``0.0`` when either
        vector has zero magnitude.
    """
    dot = np.dot(vec_a, vec_b)
    norm = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    if norm == 0:
        # A zero-length vector has no direction; dividing would give NaN and
        # a runtime warning, which then poisons any ranking built on it.
        return 0.0
    return dot / norm
def search(client, query, chunks, embeddings, top_k=3):
    """Return the ``top_k`` chunks most similar to ``query``.

    The query is embedded with ``gemini-embedding-001`` using the
    ``RETRIEVAL_QUERY`` task type, scored against every stored embedding
    with ``cosine_similarity``, and the best-scoring chunks are returned.

    Args:
        client: Object exposing ``models.embed_content``.
        query: The question text to search for.
        chunks: Chunk strings, paired positionally with ``embeddings``.
        embeddings: One embedding vector per chunk.
        top_k: How many chunks to return. Defaults to 3.

    Returns:
        A list of up to ``top_k`` chunk strings, highest similarity first.
    """
    query_config = types.EmbedContentConfig(task_type="RETRIEVAL_QUERY")
    response = client.models.embed_content(
        model="gemini-embedding-001",
        contents=query,
        config=query_config,
    )
    query_vector = response.embeddings[0].values

    scored = []
    for emb, chunk in zip(embeddings, chunks):
        scored.append((cosine_similarity(query_vector, emb), chunk))

    # Sort on the score only so equal scores keep their original order.
    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)
    return [chunk for _, chunk in ranked[:top_k]]
def test_search(client, pdf_path, question):
    """Run extract → chunk → embed → search for one PDF and print the hits.

    Args:
        client: Client object forwarded to ``embed_all_chunks`` and ``search``.
        pdf_path: Path of the PDF to index.
        question: The query to search the PDF's chunks for.
    """
    chunks = chunk_text(extract_text(pdf_path))
    embeddings = embed_all_chunks(client, chunks)
    hits = search(client, question, chunks, embeddings)
    for i, chunk in enumerate(hits, start=1):
        print(f"Result {i}:\n{chunk}\n")
def build_prompt(question, context_chunks):
    """Assemble the prompt that asks the model to answer from the given context.

    Args:
        question: The user's question, placed under the ``Question:`` heading.
        context_chunks: Iterable of chunk strings; they are joined with a
            blank line between them and placed under the ``Context:`` heading.

    Returns:
        The full prompt string.
    """
    context = "\n\n".join(context_chunks)
    return f"You are a helpful assistant. Answer the question using only the context below.\nIf the answer is not in the context, say \"I don't know.\"\n\nContext:\n{context}\n\nQuestion:\n{question}"
def generate_answer(client, prompt):
    """Send ``prompt`` to ``gemini-2.5-flash`` and return the reply text.

    Args:
        client: Object exposing ``models.generate_content``.
        prompt: The full prompt string to send.

    Returns:
        The ``text`` attribute of the generation response.
    """
    return client.models.generate_content(
        model="gemini-2.5-flash",
        contents=prompt,
    ).text
def print_result(answer, source_chunks, show_sources=True):
    """Print the answer and, optionally, the chunks it was based on.

    Args:
        answer: The answer text to display.
        source_chunks: Iterable of chunk strings used as context.
        show_sources: When true (the default), list each source chunk
            after the answer, numbered from 1.
    """
    print("Answer:")
    print(answer)
    if not show_sources:
        return

    print("\nSources:")
    for i, chunk in enumerate(source_chunks, start=1):
        print(f"Source {i}:\n{chunk}\n")
def save_embeddings(chunks, embeddings, cache_path):
    """Write chunks and their embeddings to a JSON cache file.

    The file holds a single JSON object with two keys, ``"chunks"`` and
    ``"embeddings"``, so a later run can reload both without re-embedding.
    Any existing file at ``cache_path`` is overwritten.

    Args:
        chunks: List of chunk strings.
        embeddings: List of embedding vectors, one per chunk; must be
            JSON-serializable (e.g. lists of floats).
        cache_path: Path of the JSON file to write.
    """
    # Step 1: Create data dict with "chunks" and "embeddings" keys
    data = {"chunks": chunks, "embeddings": embeddings}
    # Step 2: Open cache_path for writing
    with open(cache_path, "w") as f:
        # Step 3: json.dump(data, f)
        json.dump(data, f)
def load_embeddings(cache_path):
    """Load cached chunks and embeddings from a JSON file, if it exists.

    Args:
        cache_path: Path of the JSON cache file, expected to contain an
            object with ``"chunks"`` and ``"embeddings"`` keys.

    Returns:
        ``None`` when no file exists at ``cache_path``; otherwise a
        ``(chunks, embeddings)`` tuple read from the file.
    """
    # Step 4: If file doesn't exist, return None
    if not os.path.exists(cache_path):
        return None
    # Step 5: Open cache_path for reading
    with open(cache_path) as f:
        # Step 6: Load JSON into data
        data = json.load(f)
    # Step 7: Return data["chunks"], data["embeddings"]
    return data["chunks"], data["embeddings"]
Interactive Code Editor
Sign in to write and run code, track your progress, and unlock all chapters.
Sign In to Start Coding