Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/amberscript/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2952,15 +2952,16 @@ Result Parser::ParseBufferInitializerFill(Buffer* buffer,
}

auto fmt = buffer->GetFormat();
bool is_double_data = fmt->IsFloat32() || fmt->IsFloat64();
bool is_float_data =
fmt->IsFloat16() || fmt->IsFloat32() || fmt->IsFloat64();

// Inflate the size because our items are multi-dimensional.
size_in_items = size_in_items * fmt->InputNeededPerElement();

std::vector<Value> values;
values.resize(size_in_items);
for (size_t i = 0; i < size_in_items; ++i) {
if (is_double_data) {
if (is_float_data) {
values[i].SetDoubleValue(token->AsDouble());
} else {
values[i].SetIntValue(token->AsUint64());
Expand Down Expand Up @@ -2994,7 +2995,8 @@ Result Parser::ParseBufferInitializerSeries(Buffer* buffer,
auto n = type->AsNumber();
FormatMode mode = n->GetFormatMode();
uint32_t num_bits = n->NumBits();
if (type::Type::IsFloat32(mode, num_bits) ||
if (type::Type::IsFloat16(mode, num_bits) ||
type::Type::IsFloat32(mode, num_bits) ||
type::Type::IsFloat64(mode, num_bits)) {
counter.SetDoubleValue(token->AsDouble());
} else {
Expand All @@ -3020,7 +3022,8 @@ Result Parser::ParseBufferInitializerSeries(Buffer* buffer,
std::vector<Value> values;
values.resize(size_in_items);
for (size_t i = 0; i < size_in_items; ++i) {
if (type::Type::IsFloat32(mode, num_bits) ||
if (type::Type::IsFloat16(mode, num_bits) ||
type::Type::IsFloat32(mode, num_bits) ||
type::Type::IsFloat64(mode, num_bits)) {
double value = counter.AsDouble();
values[i].SetDoubleValue(value);
Expand Down
61 changes: 61 additions & 0 deletions src/amberscript/parser_buffer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "gtest/gtest.h"
#include "src/amberscript/parser.h"
#include "src/float16_helper.h"

namespace amber {
namespace amberscript {
Expand Down Expand Up @@ -400,6 +401,34 @@ TEST_F(AmberScriptParserTest, BufferFillFloat) {
}
}

TEST_F(AmberScriptParserTest, BufferFillFloat16) {
std::string in = "BUFFER my_buffer DATA_TYPE float16 SIZE 5 FILL 5.5";

Parser parser;
Result r = parser.Parse(in);
ASSERT_TRUE(r.IsSuccess()) << r.Error();

auto script = parser.GetScript();
const auto& buffers = script->GetBuffers();
ASSERT_EQ(1U, buffers.size());

ASSERT_TRUE(buffers[0] != nullptr);

auto* buffer = buffers[0].get();
EXPECT_EQ("my_buffer", buffer->GetName());
EXPECT_EQ(Format::Layout::kStd430, buffer->GetFormat()->GetLayout());
EXPECT_EQ(5U, buffer->ElementCount());
EXPECT_EQ(5U, buffer->ValueCount());
EXPECT_EQ(5U * sizeof(uint16_t), buffer->GetSizeInBytes());

const auto* data = buffer->GetValues<uint16_t>();
for (size_t i = 0; i < buffer->ValueCount(); ++i) {
EXPECT_FLOAT_EQ(
5.5f, float16::HexFloatToFloat(
reinterpret_cast<const uint8_t*>(&data[i]), 16));
}
}

TEST_F(AmberScriptParserTest, BufferSeries) {
std::string in =
"BUFFER my_buffer DATA_TYPE uint8 SIZE 5 SERIES_FROM 2 INC_BY 1";
Expand Down Expand Up @@ -461,6 +490,38 @@ TEST_F(AmberScriptParserTest, BufferSeriesFloat) {
}
}

TEST_F(AmberScriptParserTest, BufferSeriesFloat16) {
std::string in =
"BUFFER my_buffer DATA_TYPE float16 SIZE 5 SERIES_FROM 2.5 INC_BY "
"0.5";

Parser parser;
Result r = parser.Parse(in);
ASSERT_TRUE(r.IsSuccess()) << r.Error();

auto script = parser.GetScript();
const auto& buffers = script->GetBuffers();
ASSERT_EQ(1U, buffers.size());

ASSERT_TRUE(buffers[0] != nullptr);

auto* buffer = buffers[0].get();
EXPECT_EQ("my_buffer", buffer->GetName());
EXPECT_EQ(Format::Layout::kStd430, buffer->GetFormat()->GetLayout());
EXPECT_EQ(5U, buffer->ElementCount());
EXPECT_EQ(5U, buffer->ValueCount());
EXPECT_EQ(5U * sizeof(uint16_t), buffer->GetSizeInBytes());

std::vector<float> results = {2.5f, 3.0f, 3.5f, 4.0f, 4.5f};
const auto* data = buffer->GetValues<uint16_t>();
ASSERT_EQ(results.size(), buffer->ValueCount());
for (size_t i = 0; i < results.size(); ++i) {
EXPECT_FLOAT_EQ(
results[i], float16::HexFloatToFloat(
reinterpret_cast<const uint8_t*>(&data[i]), 16));
}
}

TEST_F(AmberScriptParserTest, BufferMultipleBuffers) {
std::string in = R"(
BUFFER color_buffer DATA_TYPE uint8 SIZE 5 FILL 5
Expand Down
6 changes: 6 additions & 0 deletions src/format.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ class Format {
type::Type::IsUint64(type_->AsNumber()->GetFormatMode(),
type_->AsNumber()->NumBits());
}
/// Returns true if all components of this format are a 16 bit float.
bool IsFloat16() const {
return type_->IsNumber() &&
type::Type::IsFloat16(type_->AsNumber()->GetFormatMode(),
type_->AsNumber()->NumBits());
}
/// Returns true if all components of this format are a 32 bit float.
bool IsFloat32() const {
return type_->IsNumber() &&
Expand Down