3. Loops and Conditions in JAX
3.1. Introduction to Loops in JAX
Loops are pivotal in repetitive tasks, such as iterating over sequences or performing computations iteratively. JAX provides various loop constructs, including jax.lax.fori_loop, jax.lax.while_loop, and jax.lax.scan, enabling fine-grained control over looping mechanisms. In this lecture, we’ll delve into these constructs and demonstrate their usage through practical examples.
import jax.numpy as jnp
import jax
3.1.1. Using jax.lax.fori_loop
jax.lax.fori_loop is a loop construct in JAX that iterates a fixed number of times. It is akin to a Python for loop over a range, but the loop body is traced once and compiled by JAX instead of being executed step by step by the Python interpreter. It is useful for repeated computations or transformations over a predetermined range of iterations.
Let’s demonstrate its usage with a simple example, starting from a plain Python loop:
# Define a Python loop to sum the squares of the integers in [start, end)
def sum_squares(start, end):
    total_sum = 0
    for i in range(start, end):  # `end` is exclusive, as in range()
        total_sum += i ** 2
    return total_sum
sum_squares(1, 10)
285
Now, let’s rewrite the above function using jax.lax.fori_loop.
jax.lax.fori_loop takes its arguments in the following order:
Start Value (lower): The initial value of the loop variable (inclusive).
End Value (upper): The upper bound for the loop variable (exclusive). The loop runs for i = lower, lower + 1, ..., upper - 1 and stops before i reaches upper.
Body Function (body_fun): The function that defines the body of the loop. It takes two arguments, the loop variable and the carry value, and returns the new carry value. The loop variable is the current iteration index, while the carry value holds any intermediate state maintained across iterations. The returned carry must have the same structure, shape, and dtype as the one passed in.
Initial Carry Value (init_val): The initial value of the carry, which is passed to the body function in the first iteration.
# Rewrite the loop using jax.lax.fori_loop
def sum_squares_jax(start, end):
    def body_fun(i, total):
        return total + i ** 2
    return jax.lax.fori_loop(start,     # lower
                             end,       # upper
                             body_fun,  # body_fun
                             0)         # init_val (of total)
sum_squares_jax(0, 10)
Array(285, dtype=int32, weak_type=True)
In this example, we define a function sum_squares_jax that computes the sum of squares from a given start value up to (but not including) an end value using jax.lax.fori_loop. The body_fun function squares the loop index i and adds it to the running total. The loop is executed with the specified start and end values, and the final value of total is returned.
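The carry does not have to be a single number; it can be a tuple (or any pytree) as long as its structure and dtypes stay the same across iterations. The following is a minimal sketch of that idea; the function name fibonacci and the choice of int32 carries are illustrative assumptions, not part of the lecture.

```python
import jax
import jax.numpy as jnp

def fibonacci(n):
    """Return the n-th Fibonacci number using a tuple carry (a, b)."""
    def body_fun(i, carry):
        a, b = carry
        # Shift the pair forward by one step: (F_k, F_{k+1}) -> (F_{k+1}, F_{k+2})
        return b, a + b

    # Both carry entries are explicit int32 scalars so the carry type is stable.
    init = (jnp.int32(0), jnp.int32(1))
    a, _ = jax.lax.fori_loop(0, n, body_fun, init)
    return a

print(fibonacci(10))  # 55
```

The loop index i is unused here; fori_loop still passes it to body_fun on every iteration.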
3.1.2. Using jax.lax.while_loop
jax.lax.while_loop is another looping construct provided by JAX, enabling iterative execution until a termination condition is met. It resembles Python’s while loop but is traced and compiled by JAX. while_loop is suitable for situations where the number of iterations is not known beforehand and depends on runtime conditions. Note that while_loop supports forward-mode differentiation but not reverse-mode differentiation.
Let’s illustrate its usage with an example, again starting from plain Python:
# Define a Python while loop to compute the factorial of `n`
def factorial(n):
    result = 1
    i = 1
    while i <= n:
        result *= i
        i += 1
    return result
factorial(6)
720
Now, let’s rewrite the above function using jax.lax.while_loop.
jax.lax.while_loop takes its arguments in the following order:
Loop Condition Function (cond_fun): This function defines the termination condition of the loop. It takes the current loop state as its argument and returns a boolean scalar indicating whether the loop should continue (True) or terminate (False).
Loop Body Function (body_fun): This function defines the body of the loop. It takes the current loop state as its argument and returns the updated loop state for the next iteration, with the same structure, shape, and dtype.
Initial Loop State (init_val): The initial state of the loop, which is passed to both the loop condition and loop body functions.
# Rewrite the loop using jax.lax.while_loop
def factorial_jax(n):
    def condition(state):
        i, result = state
        return i <= n
    def body(state):
        i, result = state
        return (i + 1, result * i)
    _, result = jax.lax.while_loop(condition,  # cond_fun
                                   body,       # body_fun
                                   (1, 1))     # init_val (i=1, result=1)
    return result
factorial_jax(6)
Array(720, dtype=int32, weak_type=True)
In this example, we define a function factorial_jax that computes the factorial of a number using jax.lax.while_loop. The condition function checks whether the loop variable i is less than or equal to n, while the body function updates the loop state by incrementing i and accumulating the factorial in result. The loop continues until the condition is False, and the final state is returned.
Since the final state is the tuple (i, result), we ignore the first value and return result.
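The factorial example has a trip count that is easy to predict from n. while_loop is most useful when the number of iterations depends on the data itself. As a minimal sketch (the function collatz_steps and the use of the Collatz iteration are illustrative assumptions, not part of the lecture), the following counts how many steps a positive integer needs to reach 1:

```python
import jax
import jax.numpy as jnp

def collatz_steps(n):
    """Count Collatz steps needed to reach 1 from a positive integer n."""
    def cond_fun(state):
        value, _ = state
        return value != 1  # keep looping until the value reaches 1

    def body_fun(state):
        value, steps = state
        # Halve even values, map odd values to 3 * value + 1.
        next_value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
        return next_value, steps + 1

    init = (jnp.int32(n), jnp.int32(0))
    _, steps = jax.lax.while_loop(cond_fun, body_fun, init)
    return steps

print(collatz_steps(6))  # 8  (6 -> 3 -> 10 -> 5 -> 16 -> 8 -> 4 -> 2 -> 1)
```

jnp.where is used inside the body because both candidate values are cheap to compute; a Python if on value would not work here, since value is not a concrete number while the loop is being traced.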
3.1.3. Using jax.lax.scan
jax.lax.scan is a function in JAX for performing cumulative computations over a sequence of inputs. It is similar to Python’s itertools.accumulate but is compiled by JAX, and in addition to the running state it can emit a separate output at every step. scan is commonly used for tasks such as computing cumulative sums or products, or applying a function iteratively over a sequence while accumulating results. It is a powerful tool for implementing recurrent neural networks, sequential models, or any computation involving cumulative operations.
jax.lax.scan is a generalized way of handling loops in JAX and can express complex looping constructs.
Let’s look at the following example, starting from plain Python:
# Define a Python function to compute cumulative sums of a list
def cumulative_sums(nums):
    sums = []
    total = 0
    for num in nums:
        total += num
        sums.append(total)
    return sums
nums = [1, 2, 3, 4, 5]
cumulative_sums(nums)
[1, 3, 6, 10, 15]
Now, let’s rewrite the above function using jax.lax.scan.
jax.lax.scan takes its arguments in the following order:
Body Function (f): This function defines the computation performed at each step of the loop. It takes two arguments, in this order: the carry (the accumulated state) and the current input element. It returns a tuple containing the updated carry and the output for this step.
Initial Carry Value (init): The initial value of the carry, which is passed to the body function in the first step.
Sequence (xs): The input sequence over which the loop iterates, along its leading axis.
scan returns a tuple of the final carry and the per-step outputs stacked into an array.
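To make the calling convention concrete, the behaviour of scan can be approximated by a plain Python loop. The helper scan_reference below is a hypothetical illustration written for this comparison, not a JAX function, and it ignores optional arguments such as length or reverse:

```python
import jax
import jax.numpy as jnp

def scan_reference(f, init, xs):
    """Pure-Python sketch of what jax.lax.scan(f, init, xs) computes."""
    carry = init
    ys = []
    for x in xs:                 # iterate over the leading axis of xs
        carry, y = f(carry, x)   # f returns (new_carry, per-step output)
        ys.append(y)
    return carry, jnp.stack(ys)  # final carry and stacked outputs

def step(total, num):
    new_total = total + num
    return new_total, new_total  # carry the running sum and also emit it

xs = jnp.array([1, 2, 3, 4, 5])
ref_carry, ref_ys = scan_reference(step, 0, xs)
lax_carry, lax_ys = jax.lax.scan(step, 0, xs)

print(ref_carry, ref_ys)
print(lax_carry, lax_ys)
```

Both calls produce a final carry of 15 and the per-step outputs 1, 3, 6, 10, 15. The difference is that scan traces step once and compiles the whole loop, while the reference version calls step from Python on every element.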
# Rewrite the computation using jax.lax.scan
def cumulative_sums_jax(nums):
    def body(total, num):
        return total + num, total + num
    total, cumulative_sums_array = jax.lax.scan(body,  # f
                                                0,     # init
                                                nums)  # xs
    return cumulative_sums_array
cumulative_sums_jax(jnp.array(nums))
Array([ 1, 3, 6, 10, 15], dtype=int32)
In this example, we define a function cumulative_sums_jax that computes cumulative sums using jax.lax.scan. The body function adds the current element to the carry and returns the new sum twice: once as the updated carry and once as the output for this step. The loop iterates over the input sequence, scan stacks the per-step outputs into an array, and that array is returned.
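In the cumulative sum, the carry and the per-step output happen to be the same value. The recurrent-model use case mentioned earlier follows the same pattern with a hidden state as the carry. Below is a minimal sketch, assuming a hypothetical scalar recurrence h_t = decay * h_{t-1} + x_t; the function name linear_recurrence and the decay parameter are illustrative only.

```python
import jax
import jax.numpy as jnp

def linear_recurrence(xs, decay=0.5):
    """Run h_t = decay * h_{t-1} + x_t over xs, starting from h_0 = 0."""
    def step(h, x):
        h_new = decay * h + x
        return h_new, h_new  # (updated carry, output recorded for this step)

    h0 = jnp.float32(0.0)    # explicit float32 so the carry dtype is stable
    final_h, hs = jax.lax.scan(step, h0, xs)
    return final_h, hs

final_h, hs = linear_recurrence(jnp.array([1.0, 2.0, 3.0]))
print(final_h)  # 4.25
print(hs)       # [1.   2.5  4.25]
```

The final carry equals the last entry of the stacked outputs here, but the two can differ in general, for example when the output is some function of the hidden state rather than the state itself.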
3.2. Conditional Execution with JAX
3.2.1. Introduction to jax.lax.cond
jax.lax.cond is a conditional execution function provided by JAX, allowing users to perform different computations based on specified conditions. This enables dynamic control flow within JAX computations, facilitating conditional branching similar to Python’s if statement. We’ll explore the usage of jax.lax.cond through practical examples.
# Define a Python function to check if a number is positive or negative
def check_sign(x):
    if x > 0:
        return 1
    else:
        return -1
# Execute the Python function with a sample input
print("Sign of 5 (Python):", check_sign(5))
print("Sign of -10 (Python):", check_sign(-10))
Sign of 5 (Python): 1
Sign of -10 (Python): -1
Let’s rewrite the same function using jax.lax.cond.
jax.lax.cond takes its arguments in the following order:
Predicate (pred): A boolean scalar indicating the condition to be evaluated. If the predicate is True, true_fun is executed; otherwise, false_fun is executed.
True Function (true_fun): The computation to be performed if the predicate is True. It receives the operands as its arguments and returns the result of the computation.
False Function (false_fun): The computation to be performed if the predicate is False. It receives the same operands and must return a result with the same structure, shape, and dtype as true_fun.
Operands: The values that are passed to whichever branch function is selected.
# Rewrite the function using jax.lax.cond
def check_sign_jax(x):
    def positive_branch(x):
        return 1
    def negative_branch(x):
        return -1
    return jax.lax.cond(x > 0,            # pred
                        positive_branch,  # true_fun
                        negative_branch,  # false_fun
                        x)                # operands
# Execute the JAX function with the same input
print("Sign of 5 (JAX cond):", check_sign_jax(5))
print("Sign of -10 (JAX cond):", check_sign_jax(-10))
Sign of 5 (JAX cond): 1
Sign of -10 (JAX cond): -1
In this example, we define a function check_sign_jax that checks if a number is positive or negative using jax.lax.cond. Depending on whether the input x is greater than 0 (positive) or not (negative), the corresponding true or false function is executed, and the result is returned.
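The branches in check_sign_jax ignore their operand. A case where the branches actually use the operands, and where more than one operand is passed, might look like the sketch below. The function safe_divide is a hypothetical example, and the explicit conversion to float32 is an assumption made so that both branches return the same dtype.

```python
import jax
import jax.numpy as jnp

def safe_divide(a, b):
    """Return a / b, or 0 when b is zero, using jax.lax.cond."""
    a = jnp.asarray(a, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)
    return jax.lax.cond(
        b != 0,                                # pred
        lambda num, den: num / den,            # true_fun: regular division
        lambda num, den: jnp.zeros_like(num),  # false_fun: fall back to 0
        a, b,                                  # operands passed to the branch
    )

print(safe_divide(6.0, 3.0))  # 2.0
print(safe_divide(1.0, 0.0))  # 0.0
```

Both branch functions accept the same operands and return a float32 scalar, which is what jax.lax.cond requires in order to compile the two branches into one conditional.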
3.3. Why do we need jax.lax?
While JAX provides high-level abstractions for numerical computing, the low-level control-flow constructs in jax.lax can lead to significant speedups, especially when compared to traditional Python for loops.
Moreover, they are sometimes necessary in order to use JAX’s JIT at all: inside a jitted function, values are abstract tracers, so Python control flow that depends on those values (an if on a traced value, or a range over a traced bound) cannot be evaluated.
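The need for jax.lax constructs under JIT can be seen by jitting the sign check from the previous section. The sketch below uses hypothetical function names (check_sign_python, check_sign_lax) and catches a generic Exception, since the exact error class raised by JAX may differ between versions.

```python
import jax

def check_sign_python(x):
    if x > 0:  # needs a concrete boolean, which a traced value cannot provide
        return 1
    return -1

def check_sign_lax(x):
    return jax.lax.cond(x > 0, lambda _: 1, lambda _: -1, x)

try:
    jax.jit(check_sign_python)(5)
    python_if_traced = True
except Exception as err:
    python_if_traced = False
    print("Python `if` failed under jit:", type(err).__name__)

check_sign_lax_jit = jax.jit(check_sign_lax)
print(check_sign_lax_jit(5))    # 1
print(check_sign_lax_jit(-10))  # -1
```

An alternative is to mark x as static with static_argnums, which makes the Python if work again at the cost of one compilation per distinct value of x.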
3.3.1. Importance of Performance Optimization
Efficient computation is essential for tackling complex problems in machine learning, scientific computing, and other domains. Performance optimization techniques, such as minimizing computational overhead and maximizing hardware utilization, are critical for achieving faster execution times and scaling to larger datasets or models.
Let’s take the example from the fori_loop section and form JIT-compiled functions from both versions:
computation_jit_lax = jax.jit(sum_squares_jax)
computation_jit_python = jax.jit(sum_squares, static_argnums=(0, 1))
# Compare execution times
x = 10000
For computation_jit_python, the static_argnums parameter is required because Python’s range needs concrete integers, whereas the arguments of a jitted function are abstract tracers by default. Marking start and end (argument indices 0 and 1) as static tells JAX to treat them as compile-time constants, so the Python loop can run while the function is being traced. The trade-off is that the function is re-traced and re-compiled for every new combination of static values.
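How often each version is traced can be observed with a Python side effect, since the Python body of a jitted function only runs while JAX traces it. This is a minimal sketch; the trace_count dictionary and the function names are hypothetical helpers introduced for this illustration.

```python
import jax

trace_count = {"static": 0, "lax": 0}

def sum_squares_static(start, end):
    trace_count["static"] += 1  # runs only while JAX traces the function
    total = 0
    for i in range(start, end):
        total += i ** 2
    return total

def sum_squares_lax(start, end):
    trace_count["lax"] += 1     # runs only while JAX traces the function
    return jax.lax.fori_loop(start, end, lambda i, total: total + i ** 2, 0)

static_jit = jax.jit(sum_squares_static, static_argnums=(0, 1))
lax_jit = jax.jit(sum_squares_lax)

for end in (5, 10, 20):
    static_jit(0, end)  # new static value -> traced and compiled again
    lax_jit(0, end)     # same argument types -> compiled function is reused

print(trace_count)  # {'static': 3, 'lax': 1}
```

The static version is traced once per distinct pair of bounds, while the jax.lax version is traced once and then reused for all three calls.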
%time result_jit_lax = computation_jit_lax(0, x)
CPU times: user 16 ms, sys: 12 µs, total: 16 ms
Wall time: 15.6 ms
%time result_jit_lax = computation_jit_lax(0, x)
CPU times: user 156 µs, sys: 0 ns, total: 156 µs
Wall time: 119 µs
The first call to computation_jit_lax takes longer because it includes the compilation overhead.
%time result_jit_python = computation_jit_python(0, x)
CPU times: user 8.51 ms, sys: 0 ns, total: 8.51 ms
Wall time: 8.24 ms
%time result_jit_python = computation_jit_python(0, x)
CPU times: user 99 µs, sys: 25 µs, total: 124 µs
Wall time: 95.8 µs
Compare the second calls of computation_jit_lax and computation_jit_python: once both are compiled, their run times in this measurement are of the same order (119 µs versus 95.8 µs).
The function computation_jit_lax nevertheless has advantages over computation_jit_python, for two main reasons:
jax.lax.fori_loop is compiled as a single loop, whereas the Python for loop in computation_jit_python runs during tracing, so its one-off tracing cost grows with the number of iterations.
computation_jit_python uses static_argnums in jax.jit, which means that for every new value of start or end it is re-compiled before the result is evaluated, which makes it even slower. Once compiled, computation_jit_lax reuses the same compiled function irrespective of the values of start and end.
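JAX dispatches computations asynchronously, so wall-clock measurements are more reliable when the result is forced before the timer stops. The sketch below shows one way to do that outside a notebook; timed_call is a hypothetical helper, not part of JAX, and the measured times will vary by machine, so none are shown.

```python
import time
import jax

def sum_squares_lax(start, end):
    return jax.lax.fori_loop(start, end, lambda i, total: total + i ** 2, 0)

sum_squares_lax_jit = jax.jit(sum_squares_lax)

def timed_call(fn, *args):
    """Call fn(*args), wait for the result, and return (result, seconds)."""
    t0 = time.perf_counter()
    out = fn(*args).block_until_ready()  # wait until the computation finishes
    return out, time.perf_counter() - t0

result, first_call = timed_call(sum_squares_lax_jit, 0, 100)  # includes compilation
_, second_call = timed_call(sum_squares_lax_jit, 0, 100)      # reuses compiled code

print(result)
print(f"first call: {first_call:.6f} s, second call: {second_call:.6f} s")
```

The first measurement includes tracing and compilation, while the second reflects only the execution of the already compiled function.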