Skip to content

srush/Tensor-Puzzles

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

72 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Tensor Puzzles

When learning a tensor programming language like PyTorch or Numpy it is tempting to rely on the standard library (or more honestly StackOverflow) to find a magic function for everything. But in practice, the tensor language is extremely expressive, and you can do most things from first principles and clever use of broadcasting.

This is a collection of 21 tensor puzzles. Like chess puzzles these are not meant to simulate the complexity of a real program, but to practice in a simplified environment. Each puzzle asks you to reimplement one function in the NumPy standard library without magic.

I recommend running in Colab. Click here and copy the notebook to get started.

Open In Colab

If you are interested, there is also a youtube walkthrough of the puzzles

Watch the video

!pip install -qqq torchtyping hypothesis pytest git+https://github.com/danoneata/chalk@srush-patch-1
!wget -q https://github.com/srush/Tensor-Puzzles/raw/main/lib.py
from lib import draw_examples, make_test, run_test
import torch
import numpy as np
from torchtyping import TensorType as TT
tensor = torch.tensor

Rules

  1. These puzzles are about broadcasting. Know this rule.

  2. Each puzzle needs to be solved in 1 line (<80 columns) of code.

  3. You are allowed @, arithmetic, comparison, shape, any indexing (e.g. a[:j], a[:, None], a[arange(10)]), and previous puzzle functions.

  4. You are not allowed anything else. No view, sum, take, squeeze, tensor.

  5. You can start with these two functions:

def arange(i: int):
    """Return the 1-D tensor ``[0, 1, ..., i - 1]``.

    Index against this tensor wherever a for-loop over ``range(i)`` would
    otherwise be needed.
    """
    indices = range(i)
    return torch.tensor(indices)

draw_examples("arange", [{"" : arange(i)} for i in [5, 3, 9]])

svg

# Example of broadcasting.
examples = [(arange(4), arange(5)[:, None]) ,
            (arange(3)[:, None], arange(2))]
draw_examples("broadcast", [{"a": a, "b":b, "ret": a + b} for a, b in examples])

svg

def where(q, a, b):
    """Select ``a`` where ``q`` is true and ``b`` elsewhere (replaces an if).

    ``q`` is used as a multiplicative mask and its complement ``~q`` masks
    ``b``; the three operands combine through ordinary broadcasting.
    """
    taken_when_true = q * a
    taken_when_false = ~q * b
    return taken_when_true + taken_when_false

# In diagrams, orange is positive/True, white is zero/False, and blue is negative.

examples = [(tensor([False]), tensor([10]), tensor([0])),
            (tensor([False, True]), tensor([1, 1]), tensor([-10, 0])),
            (tensor([False, True]), tensor([1]), tensor([-10, 0])),
            (tensor([[False, True], [True, False]]), tensor([1]), tensor([-10, 0])),
            (tensor([[False, True], [True, False]]), tensor([[0], [10]]), tensor([-10, 0])),
           ]
draw_examples("where", [{"q": q, "a":a, "b":b, "ret": where(q, a, b)} for q, a, b in examples])

svg

Puzzle 1 - ones

Compute ones - the vector of all ones.

def ones_spec(out):
    """Reference spec for ones: overwrite every entry of ``out`` with 1."""
    for position, _ in enumerate(out):
        out[position] = 1
        
# Puzzle 1 stub: replace the raise with a one-line solution that returns the
# all-ones vector of length `i` (behaviour defined by ones_spec).
def ones(i: int) -> TT["i"]:
    raise NotImplementedError

test_ones = make_test("one", ones, ones_spec, add_sizes=["i"])

svg

# run_test(test_ones)

Puzzle 2 - sum

Compute sum - the sum of a vector.

def sum_spec(a, out):
    """Reference spec for sum: accumulate every entry of ``a`` into ``out[0]``."""
    out[0] = 0
    for value in a:
        out[0] += value
        
# Puzzle 2 stub: replace the raise with a one-line solution that returns the
# sum of `a` as a 1-element tensor (see sum_spec).
# NOTE: this name shadows Python's builtin `sum` for the rest of the notebook.
def sum(a: TT["i"]) -> TT[1]:
    raise NotImplementedError


test_sum = make_test("sum", sum, sum_spec)

svg

# run_test(test_sum)

Puzzle 3 - outer

Compute outer - the outer product of two vectors.

def outer_spec(a, b, out):
    """Reference spec for outer: ``out[r][c]`` is the product ``a[r] * b[c]``."""
    for r in range(len(out)):
        # Width is taken from the first row, as every row is assumed equal.
        for c in range(len(out[0])):
            product = a[r] * b[c]
            out[r][c] = product
            
# Puzzle 3 stub: replace the raise with a one-line solution that returns the
# i-by-j outer product of vectors `a` and `b` (see outer_spec).
def outer(a: TT["i"], b: TT["j"]) -> TT["i", "j"]:
    raise NotImplementedError
    
test_outer = make_test("outer", outer, outer_spec)

svg

# run_test(test_outer)

Puzzle 4 - diag

Compute diag - the diagonal vector of a square matrix.

def diag_spec(a, out):
    """Reference spec for diag: copy the main diagonal of ``a`` into ``out``."""
    for k, row in enumerate(a):
        out[k] = row[k]
        
# Puzzle 4 stub: replace the raise with a one-line solution that returns the
# diagonal of the square matrix `a` as a vector (see diag_spec).
def diag(a: TT["i", "i"]) -> TT["i"]:
    raise NotImplementedError


test_diag = make_test("diag", diag, diag_spec)

svg

# run_test(test_diag)

Puzzle 5 - eye

Compute eye - the identity matrix.

def eye_spec(out):
    """Reference spec for eye: set the main diagonal of ``out`` to 1.

    Off-diagonal entries are left exactly as the caller supplied them.
    """
    for k, row in enumerate(out):
        row[k] = 1
        
# Puzzle 5 stub: replace the raise with a one-line solution that returns the
# j-by-j identity matrix (see eye_spec).
def eye(j: int) -> TT["j", "j"]:
    raise NotImplementedError
    
test_eye = make_test("eye", eye, eye_spec, add_sizes=["j"])

svg

# run_test(test_eye)

Puzzle 6 - triu

Compute triu - the upper triangular matrix.

def triu_spec(out):
    """Reference spec for triu: 1 on and above the diagonal, 0 below it."""
    size = len(out)
    for r in range(size):
        for c in range(size):
            out[r][c] = 1 if r <= c else 0
                
# Puzzle 6 stub: replace the raise with a one-line solution that returns the
# j-by-j matrix with ones where row <= column and zeros elsewhere (see triu_spec).
def triu(j: int) -> TT["j", "j"]:
    raise NotImplementedError


test_triu = make_test("triu", triu, triu_spec, add_sizes=["j"])

svg

# run_test(test_triu)

Puzzle 7 - cumsum

Compute cumsum - the cumulative sum.

def cumsum_spec(a, out):
    """Reference spec for cumsum: ``out[k]`` is the sum of ``a[0]`` .. ``a[k]``."""
    running = 0
    for k in range(len(out)):
        running = running + a[k]
        out[k] = running

# Puzzle 7 stub: replace the raise with a one-line solution that returns the
# cumulative sum of `a`, same length as `a` (see cumsum_spec).
def cumsum(a: TT["i"]) -> TT["i"]:
    raise NotImplementedError

test_cumsum = make_test("cumsum", cumsum, cumsum_spec)

svg

# run_test(test_cumsum)

Puzzle 8 - diff

Compute diff - the running difference.

def diff_spec(a, out):
    """Reference spec for diff: first entry copied, then successive differences."""
    out[0] = a[0]
    for upper in range(1, len(out)):
        lower = upper - 1
        out[upper] = a[upper] - a[lower]

# Puzzle 8 stub: replace the raise with a one-line solution that returns the
# running difference of `a` (see diff_spec).
# `i` is passed in alongside `a`; presumably it is the length of `a` -- confirm in lib.make_test.
def diff(a: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError

test_diff = make_test("diff", diff, diff_spec, add_sizes=["i"])

svg

# run_test(test_diff)

Puzzle 9 - vstack

Compute vstack - the matrix of two vectors

def vstack_spec(a, b, out):
    """Reference spec for vstack: ``a`` becomes row 0 and ``b`` row 1 of ``out``."""
    first_row = out[0]
    for col in range(len(first_row)):
        first_row[col] = a[col]
        out[1][col] = b[col]

# Puzzle 9 stub: replace the raise with a one-line solution that returns the
# 2-by-i matrix whose rows are `a` and `b` (see vstack_spec).
def vstack(a: TT["i"], b: TT["i"]) -> TT[2, "i"]:
    raise NotImplementedError


test_vstack = make_test("vstack", vstack, vstack_spec)

svg

# run_test(test_vstack)

Puzzle 10 - roll

Compute roll - the vector shifted 1 circular position.

def roll_spec(a, out):
    """Reference spec for roll: ``out[k]`` takes ``a[k + 1]``, wrapping at the end."""
    size = len(out)
    for k in range(size):
        # The modulo sends the final position back to index 0.
        out[k] = a[(k + 1) % size]
            
# Puzzle 10 stub: replace the raise with a one-line solution that returns `a`
# shifted one position circularly, so entry k holds a[k + 1] (see roll_spec).
def roll(a: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError


test_roll = make_test("roll", roll, roll_spec, add_sizes=["i"])

svg

# run_test(test_roll)

Puzzle 11 - flip

Compute flip - the reversed vector

def flip_spec(a, out):
    """Reference spec for flip: fill ``out`` with ``a`` read back to front."""
    last = len(out) - 1
    for k in range(len(out)):
        out[k] = a[last - k]
        
# Puzzle 11 stub: replace the raise with a one-line solution that returns `a`
# in reversed order (see flip_spec).
def flip(a: TT["i"], i: int) -> TT["i"]:
    raise NotImplementedError


test_flip = make_test("flip", flip, flip_spec, add_sizes=["i"])

svg

# run_test(test_flip)

Puzzle 12 - compress

Compute compress - keep only masked entries (left-aligned).

def compress_spec(g, v, out):
    """Reference spec for compress: pack the entries of ``v`` selected by ``g``.

    Selected values are written to the front of ``out`` in their original
    order; the remaining positions of ``out`` are not touched.
    """
    cursor = 0
    for k, keep in enumerate(g):
        if not keep:
            continue
        out[cursor] = v[k]
        cursor += 1
            
# Puzzle 12 stub: replace the raise with a one-line solution that keeps only the
# entries of `v` where mask `g` is true, left-aligned in a length-i vector
# (see compress_spec).
def compress(g: TT["i", bool], v: TT["i"], i:int) -> TT["i"]:
    raise NotImplementedError


test_compress = make_test("compress", compress, compress_spec, add_sizes=["i"])

svg

# run_test(test_compress)

Puzzle 13 - pad_to

Compute pad_to - eliminate or add 0s to change size of vector.

def pad_to_spec(a, out):
    """Reference spec for pad_to: copy as much of ``a`` as fits into ``out``.

    Only the first ``min(len(out), len(a))`` positions are written; any extra
    positions of ``out`` keep their existing values.
    """
    shared = min(len(out), len(a))
    k = 0
    while k < shared:
        out[k] = a[k]
        k += 1


# Puzzle 13 stub: replace the raise with a one-line solution that returns `a`
# truncated or zero-padded to length `j` (see pad_to_spec).
def pad_to(a: TT["i"], i: int, j: int) -> TT["j"]:
    raise NotImplementedError


test_pad_to = make_test("pad_to", pad_to, pad_to_spec, add_sizes=["i", "j"])

svg

# run_test(test_pad_to)

Puzzle 14 - sequence_mask

Compute sequence_mask - pad out to length per batch.

def sequence_mask_spec(values, length, out):
    """Reference spec for sequence_mask: zero each row past its given length.

    ``out[r][c]`` is ``values[r][c]`` while ``c < length[r]`` and 0 afterwards.
    """
    for r in range(len(out)):
        for c in range(len(out[0])):
            out[r][c] = values[r][c] if c < length[r] else 0
    
# Puzzle 14 stub: replace the raise with a one-line solution that keeps
# values[r][c] where c < length[r] and writes 0 elsewhere (see sequence_mask_spec).
def sequence_mask(values: TT["i", "j"], length: TT["i", int]) -> TT["i", "j"]:
    raise NotImplementedError


def constraint_set_length(d):
    """Reduce ``d["length"]`` modulo the column count of ``d["values"]``.

    The entry is replaced in ``d`` and the same dict is returned; it is passed
    to ``make_test`` as the ``constraint`` for the sequence_mask test.
    """
    columns = d["values"].shape[1]
    d["length"] = d["length"] % columns
    return d


test_sequence = make_test("sequence_mask",
    sequence_mask, sequence_mask_spec, constraint=constraint_set_length
)

svg

# run_test(test_sequence)

Puzzle 15 - bincount

Compute bincount - count number of times an entry was seen.

def bincount_spec(a, out):
    """Reference spec for bincount: add 1 to ``out[x]`` for each ``x`` in ``a``."""
    for entry in a:
        out[entry] += 1
        
# Puzzle 15 stub: replace the raise with a one-line solution that returns a
# length-j vector counting how often each index appears in `a` (see bincount_spec).
def bincount(a: TT["i"], j: int) -> TT["j"]:
    raise NotImplementedError


def constraint_set_max(d):
    """Reduce ``d["a"]`` modulo the length of ``d["return"]``.

    The entry is replaced in ``d`` and the same dict is returned; it is passed
    to ``make_test`` as the ``constraint`` for the bincount test.
    """
    bins = d["return"].shape[0]
    d["a"] = d["a"] % bins
    return d


test_bincount = make_test("bincount",
    bincount, bincount_spec, add_sizes=["j"], constraint=constraint_set_max
)

svg

# run_test(test_bincount)

Puzzle 16 - scatter_add

Compute scatter_add - add together values that link to the same location.

def scatter_add_spec(values, link, out):
    """Reference spec for scatter_add: add ``values[k]`` into ``out[link[k]]``."""
    for k, amount in enumerate(values):
        out[link[k]] += amount
        
# Puzzle 16 stub: replace the raise with a one-line solution that returns a
# length-j vector where position link[k] accumulates values[k] (see scatter_add_spec).
def scatter_add(values: TT["i"], link: TT["i"], j: int) -> TT["j"]:
    raise NotImplementedError


def constraint_set_max(d):
    """Reduce ``d["link"]`` modulo the length of ``d["return"]``.

    The entry is replaced in ``d`` and the same dict is returned; it is passed
    to ``make_test`` as the ``constraint`` for the scatter_add test.
    """
    slots = d["return"].shape[0]
    d["link"] = d["link"] % slots
    return d


test_scatter_add = make_test("scatter_add",
    scatter_add, scatter_add_spec, add_sizes=["j"], constraint=constraint_set_max
)

svg

# run_test(test_scatter_add)

Puzzle 17 - flatten

Compute flatten

def flatten_spec(a, out):
    """Reference spec for flatten: copy ``a`` into ``out`` row by row."""
    for r, row in enumerate(a):
        width = len(a[0])
        for c in range(width):
            # Row r starts at offset r * width in the flat output.
            out[r * width + c] = row[c]

# Puzzle 17 stub: replace the raise with a one-line solution that returns the
# i-by-j matrix `a` as a length i*j vector in row order (see flatten_spec).
def flatten(a: TT["i", "j"], i:int, j:int) -> TT["i * j"]:
    raise NotImplementedError

test_flatten = make_test("flatten", flatten, flatten_spec, add_sizes=["i", "j"])

svg

# run_test(test_flatten)

Puzzle 18 - linspace

Compute linspace

def linspace_spec(i, j, out):
    """Reference spec for linspace: evenly spaced floats starting at ``i``.

    With two or more entries the last one equals ``j``; the ``max(1, ...)``
    keeps a one-element ``out`` from dividing by zero.
    """
    intervals = max(1, len(out) - 1)
    for k in range(len(out)):
        out[k] = float(i + (j - i) * k / intervals)

# Puzzle 18 stub: replace the raise with a one-line solution that returns `n`
# evenly spaced float values running from `i` to `j` (see linspace_spec).
def linspace(i: TT[1], j: TT[1], n: int) -> TT["n", float]:
    raise NotImplementedError

test_linspace = make_test("linspace", linspace, linspace_spec, add_sizes=["n"])

svg

# run_test(test_linspace)

Puzzle 19 - heaviside

Compute heaviside

def heaviside_spec(a, b, out):
    """Reference spec for heaviside: step function of ``a`` with ``b`` at zero.

    ``out[k]`` is ``b[k]`` where ``a[k] == 0``, otherwise 1 for positive
    ``a[k]`` and 0 for negative ``a[k]``.
    """
    for k in range(len(out)):
        out[k] = b[k] if a[k] == 0 else int(a[k] > 0)

# Puzzle 19 stub: replace the raise with a one-line solution that returns, per
# element, b where a == 0, else 1 if a > 0 and 0 if a < 0 (see heaviside_spec).
def heaviside(a: TT["i"], b: TT["i"]) -> TT["i"]:
    raise NotImplementedError

test_heaviside = make_test("heaviside", heaviside, heaviside_spec)

svg

# run_test(test_heaviside)

Puzzle 20 - repeat (1d)

Compute repeat

def repeat_spec(a, d, out):
    """Reference spec for repeat: copy ``a`` into each of the first ``d[0]`` rows."""
    copies = d[0]
    for r in range(copies):
        for c, value in enumerate(a):
            out[r][c] = value

def constraint_set(d):
    """Set ``d["d"][0]`` to the number of rows of ``d["return"]``.

    ``d["d"]`` is modified in place and the same dict is returned; it is passed
    to ``make_test`` as the ``constraint`` for the repeat test.
    """
    row_count = d["return"].shape[0]
    d["d"][0] = row_count
    return d

            
# Puzzle 20 stub: replace the raise with a one-line solution that returns a
# matrix with `d` rows, each a copy of the vector `a` (see repeat_spec).
def repeat(a: TT["i"], d: TT[1]) -> TT["d", "i"]:
    raise NotImplementedError

test_repeat = make_test("repeat", repeat, repeat_spec, constraint=constraint_set)

svg

Puzzle 21 - bucketize

Compute bucketize

def bucketize_spec(v, boundaries, out):
    """Reference spec for bucketize: write each value's bucket index to ``out``.

    For every ``v[i]``, ``out[i]`` becomes ``j + 1`` for the last position ``j``
    with ``v[i] >= boundaries[j]``, or 0 when no boundary is reached.

    All boundaries, including the final one, go through the same loop, so an
    empty ``boundaries`` leaves every entry at 0 instead of raising
    ``IndexError``.
    """
    for i, val in enumerate(v):
        out[i] = 0
        # Later boundaries overwrite earlier ones, so the last one reached wins.
        for j, boundary in enumerate(boundaries):
            if val >= boundary:
                out[i] = j + 1


def constraint_set(d):
    """Replace ``d["boundaries"]`` with the running sum of its absolute values.

    Sums of non-negative terms never decrease, so the resulting boundaries are
    in non-decreasing order. The same dict is returned; it is passed to
    ``make_test`` as the ``constraint`` for the bucketize test.
    """
    magnitudes = np.abs(d["boundaries"])
    d["boundaries"] = magnitudes.cumsum()
    return d

            
# Puzzle 21 stub: replace the raise with a one-line solution that returns, for
# each entry of `v`, its bucket index relative to `boundaries` (see bucketize_spec).
def bucketize(v: TT["i"], boundaries: TT["j"]) -> TT["i"]:
    raise NotImplementedError

test_bucketize = make_test("bucketize", bucketize, bucketize_spec,
                           constraint=constraint_set)

svg

Speed Run Mode!

What is the smallest you can make each of these?

import inspect

# Puzzle functions whose solution length is reported below.
# NOTE(review): flatten, linspace, heaviside, repeat and bucketize (puzzles
# 17-21) are not listed here, so they are never measured -- confirm whether
# that is intentional.
fns = (ones, sum, outer, diag, eye, triu, cumsum, diff, vstack, roll, flip,
       compress, pad_to, sequence_mask, bincount, scatter_add)

for fn in fns:
    # Comment-only lines are dropped so they do not count against a solution.
    lines = [
        src_line
        for src_line in inspect.getsource(fn).split("\n")
        if not src_line.strip().startswith("#")
    ]

    # A one-line solution leaves three entries: the def line, the body line and
    # the empty string that follows the trailing newline.
    if len(lines) <= 3:
        print(fn.__name__, len(lines[1]))
    else:
        print(fn.__name__, len(lines[2]), "(more than 1 line)")
ones 29
sum 29
outer 29
diag 29
eye 29
triu 29
cumsum 29
diff 29
vstack 29
roll 29
flip 29
compress 29
pad_to 29
sequence_mask 29
bincount 29
scatter_add 29