// Copyright 2024 Intel Corporation.
//
// This software and the related documents are Intel copyrighted materials,
// and your use of them is governed by the express license under which they
// were provided to you ("License"). Unless the License provides otherwise,
// you may not use, modify, copy, publish, distribute, disclose or transmit
// this software or the related documents without Intel's prior written
// permission.
//
// This software and the related documents are provided as is, with no express
// or implied warranties, other than those that are expressly stated in the
// License.

/**
 * dla_lt_skip_counter.sv (defines module `dla_lt_step_counter`)
 *
 * This module is responsible for calculating the index of the first token (pixel) arriving at the
 * LT input each cycle. The remaining indices are calculated in a different pipeline.
 *
 * Initially these calculations were done using division operations, which became too slow once
 * runtime-configurability was added. Now this module uses a series of skip-counters with a
 * carry chain to calculate tensor indices for every Nth pixel, where N = (DDR-width / pixel-width).
 *
 * This counter calculates height, width, depth, and channel indices, as well as the "stride inner"
 * and "stride outer" indices. In their most basic form, the calculations performed can be written as
 *
 *   Channels: C = index % CHANNELS
 *   Width:    W = (index / CHANNELS) % WIDTH
 *   Height:   H = (index / CHANNELS / WIDTH) % HEIGHT
 *   Depth:    D = (index / CHANNELS / WIDTH / HEIGHT)
 *
 * where CHANNELS, WIDTH, HEIGHT, and DEPTH are the exact dimensions of the input tensor, and `index`
 * is an arbitrary location in the tensor (assuming DHWC memory format). The stride dimensions can
 * be written as follows:
 *
 *   Outer stride width:  stride_w = W / STRIDE_WIDTH
 *   Outer stride height: stride_h = H / STRIDE_HEIGHT
 *   Inner stride width:  inner_w  = W % STRIDE_WIDTH
 *   Inner stride height: inner_h  = H % STRIDE_HEIGHT
 *
 * The calculations above do not account for padding.
To account for padding, we add (PADDING % STRIDE) * to the numerator for each stride calculation. * * This is implemented using a multi-stage skip-counter with a carry-chain as shown below, excluding * stride counters: * * ┌────────────────┐ * │CHANNELS ┼────┐ * │ │ step * │ Step = │ │ * │ N % CHANNELS│◄───┘ * │ │ * └──────────┬─────┘ * carry-in * │ * ┌───────▼────────┐ * │WIDTH ┼────┐ * │ │ step * │ Step = │ │ * │ N / CHANNELS│◄───┘ * │ % WIDTH │ * └──────────┬─────┘ * carry-in * │ * ┌────▼───────────┐ * │HEIGHT ┼────┐ * │ │ step * │ Step = │ │ * │ N / CHANNELS│◄───┘ * │ / WIDTH │ * │ % HEIGHT │ * └────────────┬───┘ * │ * │ * ┌──────▼─────────┐ * │DEPTH ┼────┐ * │ │ step * │ Step = │ │ * │ N / CHANNELS│◄───┘ * │ / WIDTH │ * │ / HEIGHT │ * └────────────────┘ * * Note: Look at `config_bitstream_generator.cpp` for details on how each of the input values to this * module are calculated. These values are passed into hardware through the configuration network. * */ `resetall `undefineall `default_nettype none module dla_lt_step_counter #( parameter ELEMENTS_PER_CYCLE, parameter DIM_BITS, parameter DEPTH_TENSOR ) ( input wire clk, input wire i_resetn, input wire i_increment, input wire [DIM_BITS-1:0] i_channel_dim, input wire [DIM_BITS-1:0] i_width_dim, input wire [DIM_BITS-1:0] i_width_overhang, input wire [DIM_BITS-1:0] i_height_overhang, input wire [DIM_BITS-1:0] i_height_dim, input wire [DIM_BITS-1:0] i_depth_dim, input wire [DIM_BITS-1:0] i_channel_step, input wire [DIM_BITS-1:0] i_width_stride, input wire [DIM_BITS-1:0] i_width_step, input wire [DIM_BITS-1:0] i_stride_w_count, input wire [DIM_BITS-1:0] i_width_stride_step, input wire [DIM_BITS-1:0] i_width_inner_step, input wire [DIM_BITS-1:0] i_height_stride, input wire [DIM_BITS-1:0] i_height_step, input wire [DIM_BITS-1:0] i_stride_h_count, input wire [DIM_BITS-1:0] i_height_stride_step, input wire [DIM_BITS-1:0] i_height_inner_step, input wire [DIM_BITS-1:0] i_depth_step, input wire [DIM_BITS-1:0] i_pad_w, 
input wire [DIM_BITS-1:0] i_pad_h, input wire [DIM_BITS-1:0] i_continue_count_cond, input wire [DIM_BITS-1:0] i_overhang_end_w, input wire [DIM_BITS-1:0] i_w_nstrides, input wire [DIM_BITS-1:0] i_h_nstrides, output logic [DIM_BITS-1:0] o_channel, output logic [DIM_BITS-1:0] o_width, output logic [DIM_BITS-1:0] o_width_stride, output logic [DIM_BITS-1:0] o_width_inner, output logic [DIM_BITS-1:0] o_height, output logic [DIM_BITS-1:0] o_height_stride, output logic [DIM_BITS-1:0] o_height_inner, output logic [DIM_BITS-1:0] o_depth, output logic o_valid ); localparam shortint N_STAGES = 3; logic [N_STAGES-1:0] stage_cnt; assign o_valid = stage_cnt[2]; logic [DIM_BITS-1:0] channel, channel_carry; logic [DIM_BITS-1:0] width, width_stride, width_inner, width_inner_carry, width_carry; logic [DIM_BITS-1:0] height, height_stride, height_inner, height_inner_carry, height_carry; logic [DIM_BITS-1:0] depth; logic [DIM_BITS-1:0] channel_reg [N_STAGES-1:0]; logic [DIM_BITS-1:0] channel_carry_reg [N_STAGES-1:0]; logic [DIM_BITS-1:0] width_reg [N_STAGES-2:0]; logic [DIM_BITS-1:0] width_carry_reg [N_STAGES-2:0]; logic [DIM_BITS-1:0] width_inner_reg [N_STAGES-2:0]; logic [DIM_BITS-1:0] width_inner_carry_reg [N_STAGES-2:0]; logic [DIM_BITS-1:0] height_reg; logic [DIM_BITS-1:0] height_carry_reg; logic [DIM_BITS-1:0] height_inner_reg; logic [DIM_BITS-1:0] height_inner_carry_reg; logic [DIM_BITS-1:0] width_stride_reg; logic [DIM_BITS-1:0] width_stride_carry_reg; int next_stride; always_comb begin /** CHANNELS (0) **/ channel_carry = 0; if ((channel_reg[N_STAGES-1] + i_channel_step) < i_channel_dim) begin channel = channel_reg[N_STAGES-1] + i_channel_step; end else begin channel = channel_reg[N_STAGES-1] + i_channel_step - i_channel_dim; channel_carry = 1; end /** WIDTH (1) **/ width_carry = 0; if ((width_reg[N_STAGES-2] + i_width_step + channel_carry_reg[N_STAGES-1] ) < i_width_dim) begin width = width_reg[N_STAGES-2] + i_width_step + channel_carry_reg[N_STAGES-1]; end else begin width = 
width_reg[N_STAGES-2] + i_width_step + channel_carry_reg[N_STAGES-1] - i_width_dim; width_carry = 1; end /** INNER WIDTH STRIDE (2) **/ // Move down a stage.... width_inner_carry = 0; if (width_carry_reg[N_STAGES-2] & i_width_overhang != 0 & ~i_continue_count_cond) begin if (i_pad_w + i_width_inner_step + channel_carry_reg[N_STAGES-2] - (i_width_stride - width_inner_reg[N_STAGES-3] - i_overhang_end_w) < 0) begin width_inner = i_pad_w; end else if (i_pad_w + i_width_inner_step + channel_carry_reg[N_STAGES-2] - (i_width_stride - width_inner_reg[N_STAGES-3] - i_overhang_end_w) < i_width_stride) begin width_inner = i_pad_w + i_width_inner_step + channel_carry_reg[N_STAGES-2] - (i_width_stride - width_inner_reg[N_STAGES-3] - i_overhang_end_w); width_inner_carry = 1; end else begin width_inner = i_pad_w + i_width_inner_step + channel_carry_reg[N_STAGES-2] - (i_width_stride - width_inner_reg[N_STAGES-3] - i_overhang_end_w); width_inner_carry = 2; end end else if ((width_inner_reg[N_STAGES-3] + i_width_inner_step + channel_carry_reg[N_STAGES-2]) < i_width_stride) begin width_inner = width_inner_reg[N_STAGES-3] + i_width_inner_step + channel_carry_reg[N_STAGES-2]; end else begin width_inner = width_inner_reg[N_STAGES-3] + channel_carry_reg[N_STAGES-2] + i_width_inner_step - i_width_stride; width_inner_carry = 1; end /** WIDTH STRIDE (3) **/ if (width_carry_reg[N_STAGES-3] & i_width_overhang != 0 & i_continue_count_cond) begin next_stride = width_inner_carry_reg[N_STAGES-3] + i_width_stride_step + 1 - (i_w_nstrides - o_width_stride); if (next_stride < i_w_nstrides) begin width_stride = next_stride; end else begin width_stride = o_width_stride + width_inner_carry_reg[N_STAGES-3] - i_w_nstrides; end end else if ((o_width_stride + i_width_stride_step + width_inner_carry_reg[N_STAGES-3]) < i_w_nstrides) begin width_stride = o_width_stride + i_width_stride_step + width_inner_carry_reg[N_STAGES-3]; end else begin width_stride = o_width_stride + width_inner_carry_reg[N_STAGES-3] + 
i_width_stride_step - i_w_nstrides; end /** HEIGHT (2) **/ height_carry = 0; if ((height_reg + i_height_step + width_carry_reg[N_STAGES-2]) < i_height_dim) begin height = height_reg + i_height_step + width_carry_reg[N_STAGES-2]; end else begin height = height_reg + i_height_step + width_carry_reg[N_STAGES-2] - i_height_dim; height_carry = 1; end /** INNER HEIGHT STRIDE (2) **/ height_inner_carry = 0; if ((height_inner_reg + i_height_inner_step + width_carry_reg[N_STAGES-2]) < i_height_stride) begin height_inner = height_inner_reg + i_height_inner_step + width_carry_reg[N_STAGES-2]; end else begin height_inner = height_inner_reg + i_height_inner_step + width_carry_reg[N_STAGES-2] - i_height_stride; height_inner_carry = 1; end /** HEIGHT STRIDE (3) **/ if ((o_height_stride + i_height_stride_step + height_inner_carry_reg) < i_h_nstrides) begin height_stride = o_height_stride + i_height_stride_step + height_inner_carry_reg; end else begin height_stride = o_height_stride + i_height_stride_step + height_inner_carry_reg - i_h_nstrides; end /** DEPTH (3) **/ if (DEPTH_TENSOR) begin depth = o_depth + i_depth_step + height_carry_reg; // Shouldn't overflow - then we've reached the end. 
end else begin depth = 0; end end always_ff @( posedge clk ) begin if (i_increment) begin if (~stage_cnt[N_STAGES-1]) stage_cnt <= (stage_cnt << 1) | 1; // 0 channel_reg[N_STAGES-1] <= channel; channel_reg[N_STAGES-2] <= channel_reg[N_STAGES-1]; channel_reg[N_STAGES-3] <= channel_reg[N_STAGES-2]; channel_carry_reg[N_STAGES-1] <= channel_carry; channel_carry_reg[N_STAGES-2] <= channel_carry_reg[N_STAGES-1]; channel_carry_reg[N_STAGES-3] <= channel_carry_reg[N_STAGES-2]; // 1 if (stage_cnt[0]) begin width_reg[N_STAGES-2] <= width; width_reg[N_STAGES-3] <= width_reg[N_STAGES-2]; width_carry_reg[N_STAGES-2] <= width_carry; width_carry_reg[N_STAGES-3] <= width_carry_reg[N_STAGES-2]; end // 2 width_inner_reg[N_STAGES-3] <= i_width_overhang; height_inner_reg <= i_height_overhang; if (stage_cnt[1]) begin width_inner_carry_reg[N_STAGES-3] <= width_inner_carry; width_inner_reg[N_STAGES-3] <= width_inner; height_reg <= height; height_carry_reg <= height_carry; height_inner_reg <= height_inner; height_inner_carry_reg <= height_inner_carry; end // preload these values which contain padding information. 
o_height_inner <= height_inner_reg; o_width_inner <= width_inner_reg[0]; // 3 if (stage_cnt[2]) begin o_channel <= channel_reg[0]; o_width <= width_reg[0]; o_height <= height_reg; o_depth <= depth; o_height_stride <= height_stride; o_width_stride <= width_stride; end end else begin o_channel <= o_channel; o_width <= o_width; o_height <= o_height; o_depth <= o_depth; o_height_inner <= o_height_inner; o_width_inner <= o_width_inner; o_height_stride <= o_height_stride; o_width_stride <= o_width_stride; end if (~i_resetn) begin o_channel <= '0; o_width <= '0; o_height <= '0; o_depth <= '0; o_height_inner <= '0; o_width_inner <= '0; o_height_stride <= '0; o_width_stride <= '0; stage_cnt <= '0; height_reg <= '0; height_carry_reg <= '0; height_inner_reg <= '0; height_inner_carry_reg <= '0; width_stride_reg <= '{default: '0}; width_stride_carry_reg <= '{default: '0}; channel_reg <= '{default: '0}; channel_carry_reg <= '{default: '0}; width_reg <= '{default: '0}; width_carry_reg <= '{default: '0}; width_inner_reg <= '{default: '0}; width_inner_carry_reg <= '{default: '0}; end end endmodule