Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept left offsets when applying position encodings #1374

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ set(SOURCES
src/ops/mul.cc
src/ops/multinomial.cc
src/ops/multinomial_cpu.cc
src/ops/position_encodings_add.cc
src/ops/position_encodings_add_cpu.cc
src/ops/quantize.cc
src/ops/quantize_cpu.cc
src/ops/relu.cc
Expand Down Expand Up @@ -511,6 +513,7 @@ if (WITH_CUDA)
src/ops/layer_norm_gpu.cu
src/ops/mean_gpu.cu
src/ops/multinomial_gpu.cu
src/ops/position_encodings_add_gpu.cu
src/ops/rms_norm_gpu.cu
src/ops/rotary_gpu.cu
src/ops/softmax_gpu.cu
Expand Down
8 changes: 6 additions & 2 deletions include/ctranslate2/layers/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ namespace ctranslate2 {
// Base class for position encoders.
class PositionEncoder : public Layer {
public:
void operator()(StorageView& input, dim_t index = 0);
void operator()(const StorageView& input, StorageView& output, dim_t index = 0);
void operator()(const StorageView& input,
StorageView& output,
dim_t step = 0,
const StorageView* offsets = nullptr);
protected:
virtual const StorageView& get_position_encoding(dim_t max_time) = 0;
private:
ops::PositionEncodingsAdd _add_op;
};

// Concrete position encoder loading encoding vectors from the model.
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@
#include "median_filter.h"
#include "rotary.h"
#include "alibi_add.h"
#include "position_encodings_add.h"
26 changes: 26 additions & 0 deletions include/ctranslate2/ops/position_encodings_add.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include "op.h"

namespace ctranslate2 {
namespace ops {

// Op adding position encodings to an input indexed as (batch, time, depth).
//
// `encodings` is read as a 2D table (position, depth). The call throws
// std::runtime_error when `step + time` exceeds the number of encoding rows and
// std::invalid_argument when the encoding depth differs from the input depth.
class PositionEncodingsAdd : public Op {
public:
// Adds position encodings to `input` and stores the result in `output`
// (resized to the shape of `input`).
//
// Without `offsets`, the encodings starting at row `step` are added to every
// batch entry. With `offsets` (read as one int32 value per batch entry), the
// encoding row used for batch b and timestep t is `t - offsets[b] + step`;
// no encoding is added for positions where this index is negative.
void operator()(const StorageView& input,
const StorageView& encodings,
StorageView& output,
const StorageView* offsets = nullptr,
const dim_t step = 0) const;

private:
// Device-specific implementation of the path that takes `offsets`; defined
// and explicitly instantiated in the CPU and CUDA source files.
template <Device D, typename T>
void compute(const dim_t step,
const StorageView* offsets,
const StorageView& input,
const StorageView& encodings,
StorageView& output) const;
};

}
}
34 changes: 6 additions & 28 deletions src/layers/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include "ctranslate2/ops/activation.h"
#include "cpu/backend.h"
#include "dispatch.h"

namespace ctranslate2 {
namespace layers {
Expand Down Expand Up @@ -148,34 +147,13 @@ namespace ctranslate2 {
}


void PositionEncoder::operator()(StorageView& input, dim_t index) {
void PositionEncoder::operator()(const StorageView& input,
StorageView& output,
dim_t step,
const StorageView* offsets) {
const dim_t time = input.dim(1);
const dim_t depth = input.dim(-1);
const dim_t max_time = std::max(time, index + 1);
const StorageView& encodings = get_position_encoding(max_time);
const dim_t num_encodings = encodings.dim(0);

if (max_time > num_encodings)
throw std::runtime_error("No position encodings are defined for positions >= "
+ std::to_string(num_encodings)
+ ", but got position "
+ std::to_string(max_time - 1));
if (depth != encodings.dim(1))
throw std::invalid_argument("Shape mismatch: position encodings have depth "
+ std::to_string(encodings.dim(1))
+ ", but the input has depth "
+ std::to_string(depth));

DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(),
primitives<D>::add_batch_broadcast(encodings.data<T>() + index * depth,
input.data<T>(),
time * depth,
input.size()));
}

void PositionEncoder::operator()(const StorageView& input, StorageView& output, dim_t index) {
output = input;
operator()(output, index);
const StorageView& encodings = get_position_encoding(step + time);
_add_op(input, encodings, output, offsets, step);
}


Expand Down
4 changes: 2 additions & 2 deletions src/layers/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ namespace ctranslate2 {
if (_embeddings_scale)
ops::Mul()(input, *_embeddings_scale, input);
if (_position_encoder)
(*_position_encoder)(input);
(*_position_encoder)(input, input);
if (_layernorm_embedding)
(*_layernorm_embedding)(input, input);

Expand Down Expand Up @@ -460,7 +460,7 @@ namespace ctranslate2 {
if (layer_in.rank() == 2)
layer_in.expand_dims(1);
if (_position_encoder)
(*_position_encoder)(layer_in, std::max(step, dim_t(0)));
(*_position_encoder)(layer_in, layer_in, std::max(step, dim_t(0)));
if (_layernorm_embedding)
(*_layernorm_embedding)(layer_in, layer_in);

Expand Down
2 changes: 1 addition & 1 deletion src/layers/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace ctranslate2 {
_gelu(output, output);

_transpose(output, input);
_position_embedding(input);
_position_embedding(input, input);

for (const auto& layer : _layers) {
(*layer)(input, nullptr, output);
Expand Down
50 changes: 50 additions & 0 deletions src/ops/position_encodings_add.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "ctranslate2/ops/position_encodings_add.h"

#include "dispatch.h"

namespace ctranslate2 {
namespace ops {

// Validates the shapes, sizes the output like the input, and adds the position
// encodings. When `offsets` is null the encodings starting at row `step` are
// broadcast over the batch; otherwise the device-specific `compute` resolves a
// per-batch encoding row.
//
// Throws std::runtime_error when `step + time` goes past the last encoding row
// and std::invalid_argument when the encoding depth differs from the input depth.
void PositionEncodingsAdd::operator()(const StorageView& input,
                                      const StorageView& encodings,
                                      StorageView& output,
                                      const StorageView* offsets,
                                      const dim_t step) const {
  PROFILE("PositionEncodingsAdd");

  const dim_t num_steps = input.dim(1);
  const dim_t hidden_size = input.dim(2);

  // One past the last position that will be looked up in the encoding table.
  const dim_t end_position = num_steps + step;
  const dim_t num_encodings = encodings.dim(0);
  if (end_position > num_encodings) {
    throw std::runtime_error("No position encodings are defined for positions >= "
                             + std::to_string(num_encodings)
                             + ", but got position "
                             + std::to_string(end_position - 1));
  }

  const dim_t encodings_depth = encodings.dim(1);
  if (hidden_size != encodings_depth) {
    throw std::invalid_argument("Shape mismatch: position encodings have depth "
                                + std::to_string(encodings_depth)
                                + ", but the input has depth "
                                + std::to_string(hidden_size));
  }

  output.resize_as(input);

  if (!offsets) {
    // Every batch entry uses the same contiguous slice of encodings.
    DEVICE_AND_FLOAT_DISPATCH(
      "PositionEncodingsAdd", input.device(), input.dtype(),
      (primitives<D>::add_batch_broadcast(encodings.data<T>() + step * hidden_size,
                                          input.data<T>(),
                                          output.data<T>(),
                                          num_steps * hidden_size,
                                          input.size())));
    return;
  }

  // Each batch entry may be shifted differently: defer to the device kernel.
  DEVICE_AND_FLOAT_DISPATCH(
    "PositionEncodingsAdd", input.device(), input.dtype(),
    (compute<D, T>(step, offsets, input, encodings, output)));
}

}
}
48 changes: 48 additions & 0 deletions src/ops/position_encodings_add_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "ctranslate2/ops/position_encodings_add.h"

#include "cpu/parallel.h"

namespace ctranslate2 {
namespace ops {

// CPU implementation of the position encodings addition with per-batch offsets.
//
// `input` and `output` are indexed as (batch, time, depth) and `encodings` as
// (position, depth). For batch entry b and timestep t, the encoding row
// `t - offsets[b] + step` is added to the input row. Rows whose encoding index
// is negative receive no encoding: the input row is forwarded unchanged.
template <Device D, typename T>
void PositionEncodingsAdd::compute(const dim_t step,
                                   const StorageView* offsets,
                                   const StorageView& input,
                                   const StorageView& encodings,
                                   StorageView& output) const {
  const dim_t batch_size = input.dim(0);
  const dim_t time = input.dim(1);
  const dim_t depth = input.dim(2);

  cpu::parallel_for(0, batch_size * time, 1, [&](const dim_t begin, const dim_t end) {
    for (dim_t i = begin; i < end; ++i) {
      const dim_t b = i / time;
      const dim_t t = i % time;

      // Use the template type instead of a hard-coded float so that the
      // function body stays correct for any instantiated type.
      const T* src = input.index<T>({b, t, 0});
      T* dst = output.index<T>({b, t, 0});

      const dim_t offset = offsets ? offsets->at<int32_t>(b) : 0;
      const dim_t encoding_offset = t - offset + step;

      if (encoding_offset < 0) {
        // No encoding applies to this position. When the op is not running
        // in-place the output row must still be populated, otherwise it would
        // keep whatever the freshly resized buffer contained.
        if (dst != src)
          primitives<D>::copy(src, dst, depth);
        continue;
      }

      primitives<D>::add(encodings.index<T>({encoding_offset, 0}), src, dst, depth);
    }
  });
}

// Explicitly instantiates PositionEncodingsAdd::compute for the CPU device and
// the given element type T.
#define DECLARE_IMPL(T) \
template void \
PositionEncodingsAdd::compute<Device::CPU, T>(const dim_t, \
const StorageView*, \
const StorageView&, \
const StorageView&, \
StorageView&) const;

// Only the float instantiation is provided in this file.
DECLARE_IMPL(float)

}
}
73 changes: 73 additions & 0 deletions src/ops/position_encodings_add_gpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include "ctranslate2/ops/position_encodings_add.h"

#include "type_dispatch.h"
#include "cuda/helpers.h"

namespace ctranslate2 {
namespace ops {

template <typename T, typename AddFunc>
__global__ void position_encodings_add_kernel(const T* input,
const T* encodings,
T* output,
const int32_t* offsets,
cuda::index_t step,
cuda::index_t max_time,
cuda::index_t depth,
const AddFunc& add_func) {
const cuda::index_t batch = blockIdx.x / max_time;
const cuda::index_t time = blockIdx.x % max_time;

const int32_t offset = offsets ? offsets[batch] : 0;
const int32_t encoding_offset = time - offset + step;

if (encoding_offset < 0)
return;

input += blockIdx.x * depth;
output += blockIdx.x * depth;
encodings += encoding_offset * depth;

for (cuda::index_t i = threadIdx.x; i < depth; i += blockDim.x) {
output[i] = add_func(input[i], encodings[i]);
}
}

template <Device D, typename T>
void PositionEncodingsAdd::compute(const dim_t step,
const StorageView* offsets,
const StorageView& input,
const StorageView& encodings,
StorageView& output) const {
const dim_t batch_size = input.dim(0);
const dim_t time = input.dim(1);
const dim_t depth = input.dim(2);

const dim_t blocks = std::min(batch_size * time, cuda::max_blocks);
const dim_t threads = std::min(depth, cuda::max_threads);

position_encodings_add_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
cuda::device_cast(input.data<T>()),
cuda::device_cast(encodings.data<T>()),
cuda::device_cast(output.data<T>()),
offsets ? offsets->data<int32_t>() : nullptr,
step,
time,
depth,
cuda::plus<cuda::device_type<T>>());
}

// Explicitly instantiates PositionEncodingsAdd::compute for the CUDA device and
// the given element type T.
#define DECLARE_IMPL(T) \
template void \
PositionEncodingsAdd::compute<Device::CUDA, T>(const dim_t, \
const StorageView*, \
const StorageView&, \
const StorageView&, \
StorageView&) const;

// Element types instantiated for the CUDA implementation.
DECLARE_IMPL(float)
DECLARE_IMPL(float16_t)
DECLARE_IMPL(bfloat16_t)

}
}
34 changes: 32 additions & 2 deletions tests/layers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,36 @@ TEST(LayerTest, PadderIgnore) {
expect_storage_eq(x, original);
}

// Checks that left offsets shift the position used for each batch entry:
// encoding a single timestep at step 5 with offsets {3, 1} must produce the
// encodings found at positions 2 and 4 of an unshifted run.
TEST_P(LayerDeviceFPTest, PositionEncoderOffset) {
  const auto& param = GetParam();
  const Device device = param.device;
  const DataType dtype = param.dtype;
  const float error = param.error;

  layers::SinusoidalPositionEncoder position_encoder(4, dtype, device);

  // Reference values: encode 5 positions of a zero input without any offset,
  // then gather the rows at positions 2 (batch 0) and 4 (batch 1).
  StorageView reference(dtype, device);
  {
    StorageView zeros({2, 5, 4}, 0.f, device);
    StorageView all_positions(dtype, device);
    position_encoder(zeros.to(dtype), all_positions);

    StorageView position_ids({2, 1}, std::vector<int32_t>{2, 4}, device);
    ops::Gather gather_op(/*axis=*/1, /*batch_dims=*/1);
    gather_op(all_positions, position_ids, reference);
  }

  // Value under test: a single timestep encoded with a step and per-batch offsets.
  StorageView offsets({2}, std::vector<int32_t>{3, 1}, device);
  const dim_t step = 5;

  StorageView zeros({2, 1, 4}, 0.f, device);
  StorageView shifted(dtype, device);
  position_encoder(zeros.to(dtype), shifted, step, &offsets);

  expect_storage_eq(shifted.to_float32(), reference.to_float32(), error);
}

TEST(LayerTest, PositionEncoderNoSharedState) {
// Test case for issue: http://forum.opennmt.net/t/ctranslate2-c-api-returns-strange-results-when-initializing-2-models/3208
layers::SinusoidalPositionEncoder position_encoder_1(4);
Expand All @@ -233,7 +263,7 @@ TEST(LayerTest, PositionEncoderNoSharedState) {
{1, 1, 4}, std::vector<float>{0.1, -2.3, 0.5, 1.2});
StorageView expected(
{1, 1, 4}, std::vector<float>{0.941471, -2.2999, 1.0403, 2.2});
position_encoder_1(input);
position_encoder_1(input, input);
expect_storage_eq(input, expected, 1e-5);
}

Expand All @@ -242,7 +272,7 @@ TEST(LayerTest, PositionEncoderNoSharedState) {
{1, 1, 6}, std::vector<float>{-0.2, -1.3, 0.1, -0.6, 2.0, 1.1});
StorageView expected(
{1, 1, 6}, std::vector<float>{0.641471, -1.29, 0.1001, -0.0596977, 2.99995, 2.1});
position_encoder_2(input);
position_encoder_2(input, input);
expect_storage_eq(input, expected, 1e-5);
}
}
Expand Down