jake.tech

CDK Patterns - Step Functions


We’re doing some numerical work in Lambda. There is some numerical instability and we want to make sure that we can reliably get the same results. We could use the AWS SDK to invoke Lambda in parallel:

def generate(lambda_payload):
    payload_bytes = bytes(json.dumps(lambda_payload), "UTF-8")
    response = lambda_client.invoke(
        FunctionName="arn:aws:lambda:ap-southeast-2:...",
        InvocationType="RequestResponse",
        Payload=payload_bytes,
        LogType="Tail",
    )
    if response["StatusCode"] == 200:
        payload = response["Payload"].read()
        output.append(payload)
        return payload
    else:
        raise Exception("Non-200 response")


output = []
threads = []

for idx in range(0, 180):
    x = threading.Thread(target=generate, args=(DEMO_PAYLOAD,))
    threads.append(x)


print(f"Executing {len(threads)} threads in parallel")

t1 = time()
for thread in threads:
    thread.start()

print("Running...")

for thread in threads:
    thread.join()
t2 = time()

print(output[0])
for line in output:
    if output[0] != line:
        print(line)
        raise Exception("Difference found")

print(f"Elapsed time for {len(threads)} threads: {t2-t1}")
print(f"Output length: {len(output)}")

But this is a bit brittle and we need to keep track of everything in memory. An alternative would be to string together a system with SQS, but AWS Step functions provides an easier way:

We will implement this with the CDK in TypeScript. Let’s get started.

Imports, nothing unusual here:

import * as cdk from "@aws-cdk/core";
import * as lambda from "@aws-cdk/aws-lambda";
import * as sfn from "@aws-cdk/aws-stepfunctions";
import * as tasks from "@aws-cdk/aws-stepfunctions-tasks";
import * as logs from "@aws-cdk/aws-logs";
import * as s3 from "@aws-cdk/aws-s3";
import { PythonFunction } from "@aws-cdk/aws-lambda-python";

Now we define all of our Lambda functions:

const dockerImage = lambda.DockerImageCode.fromImageAsset("..");

const splitterFunc = new PythonFunction(this, "SplitterFunc", {
    entry: "../lambda",
    index: "lambda_helpers.py",
    handler: "splitter",
    runtime: lambda.Runtime.PYTHON_3_8,
});

const reducerFunc = new PythonFunction(this, "ReducerFunc", {
    entry: "../lambda",
    index: "lambda_helpers.py",
    handler: "reducer",
    runtime: lambda.Runtime.PYTHON_3_8,
});

const dockerFunc = new lambda.DockerImageFunction(this, "clmath", {
    code: dockerImage,
    memorySize: 512,
    timeout: cdk.Duration.seconds(900),
    environment: {
    NUM_WORKERS: "1",
    HASH_ONLY: "TRUE",
    },
    currentVersionOptions: {
    removalPolicy: cdk.RemovalPolicy.RETAIN, // retain old versions
    retryAttempts: 1, // async retry attempts
    },
});

A quick detour to look at the Python code for our helper functions. In this case, our reducer just checks to make sure all results are identical.

def reducer(event, context):
    for payload in event:
        if payload!=event[0]:
            logging.info(f"Difference found: {payload} - {event[0]}")
            return False
    return True


def splitter(event, context):
    return [event]*20

Back to the CDK, now we map our functions to LambdaInvoke steps…

const reduceJob = new tasks.LambdaInvoke(this, "Check Job", {
    lambdaFunction: reducerFunc,
    inputPath: "$",
    outputPath: "$",
});

const splitJob = new tasks.LambdaInvoke(this, "Split Job", {
    lambdaFunction: splitterFunc,
    inputPath: "$",
    outputPath: "$.Payload",
});

// Append lambda output to the input as lambda_response
const calcJob = new tasks.LambdaInvoke(this, "Calc Job", {
    lambdaFunction: dockerFunc,
    outputPath: "$.Payload.SHAH_HASH",
});
const calcMap = new sfn.Map(this, "Calc Map", {
    maxConcurrency: 0, // No limit, run as many concurrently as possible
});
calcMap.iterator(calcJob);

const splitMap = new sfn.Map(this, "Split Map", {
    maxConcurrency: 0, // No limit, run as many concurrently as possible
});

Now we string everything together:

const success = new sfn.Succeed(this, "Success");
const fail = new sfn.Fail(this, "Differences found");

const choiceFunc = new sfn.Choice(this, "Choice Func")
    .when(sfn.Condition.booleanEquals("$.Payload", true), success)
    .when(sfn.Condition.booleanEquals("$.Payload", false), fail);

const subWorkflow = calcMap.next(reduceJob).next(choiceFunc);
splitMap.iterator(splitJob.next(subWorkflow));

const definition = splitMap.next(new sfn.Succeed(this, "FinalSuccess"));

const logGroup = new logs.LogGroup(this, "lg");
new sfn.StateMachine(this, "clmath-standard", {
    definition,
    timeout: cdk.Duration.seconds(3600),
    tracingEnabled: true,
    logs: {
    destination: logGroup,
    level: sfn.LogLevel.ALL,
    },
});

Finally, we invoke it using Python+Boto3:

client = boto3.client("stepfunctions")
t1 = time()
# Submit for execution
response = client.start_execution(
    stateMachineArn="arn:aws:states:ap-southeast-...",
    input=json.dumps(payload),
)
arn = response["executionArn"]

# Now poll until completion
print(f"Running... ExecutionArn: {arn}")
while True:
    response = client.describe_execution(executionArn=arn)
    status = response["status"]
    if status != "RUNNING":
        print(response)
        break
    sleep(1)
t2 = time()
print(f"Completed in {t2-t1} s")